1use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
22use futures::{prelude::*, io::IoSlice};
23use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
24
/// Upper bound on the number of bytes in an unsigned-varint length prefix.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame length that fits in `MAX_LEN_BYTES` varint bytes. Each
/// varint byte contributes 7 value bits (its top bit is the continuation
/// flag), so with 2 bytes this is `2^14 - 1 = 16383`.
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 7)) - 1;
/// Initial capacity, in bytes, used when allocating the frame buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
28
#[pin_project::pin_project]
/// Wraps an I/O resource and exchanges frames that are prefixed with their
/// length encoded as an unsigned varint.
///
/// Implements `Stream` (yielding decoded frames) when `R: AsyncRead`, and
/// `Sink<Bytes>` (buffering encoded frames) when `R: AsyncWrite`.
#[derive(Debug)]
pub struct LengthDelimited<R> {
    /// The underlying I/O resource frames are read from / written to.
    #[pin]
    inner: R,
    /// Holds the payload bytes of the frame currently being read.
    read_buffer: BytesMut,
    /// Encoded frames (varint length prefix + payload) not yet written to `inner`.
    write_buffer: BytesMut,
    /// Tracks progress through the length prefix or payload of the current frame.
    read_state: ReadState,
}
49
/// Decoder state of `LengthDelimited` for the frame currently being read.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
    /// Reading the varint length prefix: `buf` holds the prefix bytes read so
    /// far and `pos` counts how many of them have been read.
    ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
    /// Reading the payload: `len` is the decoded frame length and `pos` counts
    /// how many payload bytes have been read into the read buffer.
    ReadData { len: u16, pos: usize },
}
57
58impl Default for ReadState {
59 fn default() -> Self {
60 ReadState::ReadLength {
61 buf: [0; MAX_LEN_BYTES as usize],
62 pos: 0
63 }
64 }
65}
66
67impl<R> LengthDelimited<R> {
68 pub fn new(inner: R) -> LengthDelimited<R> {
71 LengthDelimited {
72 inner,
73 read_state: ReadState::default(),
74 read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
75 write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
76 }
77 }
78
79 pub fn into_inner(self) -> R {
88 assert!(self.read_buffer.is_empty());
89 assert!(self.write_buffer.is_empty());
90 self.inner
91 }
92
93 pub fn into_reader(self) -> LengthDelimitedReader<R> {
101 LengthDelimitedReader { inner: self }
102 }
103
104 pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>)
110 -> Poll<Result<(), io::Error>>
111 where
112 R: AsyncWrite
113 {
114 let mut this = self.project();
115
116 while !this.write_buffer.is_empty() {
117 match this.inner.as_mut().poll_write(cx, &this.write_buffer) {
118 Poll::Pending => return Poll::Pending,
119 Poll::Ready(Ok(0)) => {
120 return Poll::Ready(Err(io::Error::new(
121 io::ErrorKind::WriteZero,
122 "Failed to write buffered frame.")))
123 }
124 Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
125 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
126 }
127 }
128
129 Poll::Ready(Ok(()))
130 }
131}
132
133impl<R> Stream for LengthDelimited<R>
134where
135 R: AsyncRead
136{
137 type Item = Result<Bytes, io::Error>;
138
139 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 let mut this = self.project();
141
142 loop {
143 match this.read_state {
144 ReadState::ReadLength { buf, pos } => {
145 match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
146 Poll::Ready(Ok(0)) => {
147 if *pos == 0 {
148 return Poll::Ready(None);
149 } else {
150 return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
151 }
152 }
153 Poll::Ready(Ok(n)) => {
154 debug_assert_eq!(n, 1);
155 *pos += n;
156 }
157 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
158 Poll::Pending => return Poll::Pending,
159 };
160
161 if (buf[*pos - 1] & 0x80) == 0 {
162 let (len, _) = unsigned_varint::decode::u16(buf)
164 .map_err(|e| {
165 log::debug!("invalid length prefix: {}", e);
166 io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
167 })?;
168
169 if len >= 1 {
170 *this.read_state = ReadState::ReadData { len, pos: 0 };
171 this.read_buffer.resize(len as usize, 0);
172 } else {
173 debug_assert_eq!(len, 0);
174 *this.read_state = ReadState::default();
175 return Poll::Ready(Some(Ok(Bytes::new())));
176 }
177 } else if *pos == MAX_LEN_BYTES as usize {
178 return Poll::Ready(Some(Err(io::Error::new(
181 io::ErrorKind::InvalidData,
182 "Maximum frame length exceeded"))));
183 }
184 }
185 ReadState::ReadData { len, pos } => {
186 match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
187 Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
188 Poll::Ready(Ok(n)) => *pos += n,
189 Poll::Pending => return Poll::Pending,
190 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
191 };
192
193 if *pos == *len as usize {
194 let frame = this.read_buffer.split_off(0).freeze();
196 *this.read_state = ReadState::default();
197 return Poll::Ready(Some(Ok(frame)));
198 }
199 }
200 }
201 }
202 }
203}
204
205impl<R> Sink<Bytes> for LengthDelimited<R>
206where
207 R: AsyncWrite,
208{
209 type Error = io::Error;
210
211 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
216 match self.as_mut().poll_write_buffer(cx) {
217 Poll::Ready(Ok(())) => {},
218 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
219 Poll::Pending => return Poll::Pending,
220 }
221
222 debug_assert!(self.as_mut().project().write_buffer.is_empty());
223 }
224
225 Poll::Ready(Ok(()))
226 }
227
228 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
229 let this = self.project();
230
231 let len = match u16::try_from(item.len()) {
232 Ok(len) if len <= MAX_FRAME_SIZE => len,
233 _ => {
234 return Err(io::Error::new(
235 io::ErrorKind::InvalidData,
236 "Maximum frame size exceeded."))
237 }
238 };
239
240 let mut uvi_buf = unsigned_varint::encode::u16_buffer();
241 let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
242 this.write_buffer.reserve(len as usize + uvi_len.len());
243 this.write_buffer.put(uvi_len);
244 this.write_buffer.put(item);
245
246 Ok(())
247 }
248
249 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
250 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
252 Poll::Ready(Ok(())) => {},
253 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
254 Poll::Pending => return Poll::Pending,
255 }
256
257 let this = self.project();
258 debug_assert!(this.write_buffer.is_empty());
259
260 this.inner.poll_flush(cx)
262 }
263
264 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
265 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
267 Poll::Ready(Ok(())) => {},
268 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
269 Poll::Pending => return Poll::Pending,
270 }
271
272 let this = self.project();
273 debug_assert!(this.write_buffer.is_empty());
274
275 this.inner.poll_close(cx)
277 }
278}
279
#[pin_project::pin_project]
/// A `LengthDelimited` that keeps decoding frames on the read side (via its
/// `Stream` impl), while its `AsyncWrite` impl writes raw, unframed bytes to
/// the underlying resource after any already-buffered frames.
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
    /// The wrapped codec, including any frames still buffered for writing.
    #[pin]
    inner: LengthDelimited<R>
}
288
289impl<R> LengthDelimitedReader<R> {
290 pub fn into_inner(self) -> R {
303 self.inner.into_inner()
304 }
305}
306
307impl<R> Stream for LengthDelimitedReader<R>
308where
309 R: AsyncRead
310{
311 type Item = Result<Bytes, io::Error>;
312
313 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
314 self.project().inner.poll_next(cx)
315 }
316}
317
318impl<R> AsyncWrite for LengthDelimitedReader<R>
319where
320 R: AsyncWrite
321{
322 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
323 -> Poll<Result<usize, io::Error>>
324 {
325 let mut this = self.project().inner;
327
328 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
330 Poll::Ready(Ok(())) => {},
331 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
332 Poll::Pending => return Poll::Pending,
333 }
334 debug_assert!(this.write_buffer.is_empty());
335
336 this.project().inner.poll_write(cx, buf)
337 }
338
339 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
340 self.project().inner.poll_flush(cx)
341 }
342
343 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
344 self.project().inner.poll_close(cx)
345 }
346
347 fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
348 -> Poll<Result<usize, io::Error>>
349 {
350 let mut this = self.project().inner;
352
353 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
355 Poll::Ready(Ok(())) => {},
356 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
357 Poll::Pending => return Poll::Pending,
358 }
359 debug_assert!(this.write_buffer.is_empty());
360
361 this.project().inner.poll_write_vectored(cx, bufs)
362 }
363}
364
/// Tests for `LengthDelimited`: decoding hand-built byte streams, error cases,
/// and a TCP round trip of arbitrary frames.
#[cfg(test)]
mod tests {
    use crate::length_delimited::LengthDelimited;
    use async_std::net::{TcpListener, TcpStream};
    use futures::{prelude::*, io::Cursor};
    use quickcheck::*;
    use std::io::ErrorKind;

    /// One frame: length prefix `6` followed by six payload bytes.
    #[test]
    fn basic_read() {
        let data = vec![6, 9, 8, 7, 6, 5, 4];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    /// Two back-to-back frames with prefixes `6` and `3`.
    #[test]
    fn basic_read_two() {
        let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    /// A 5000-byte frame whose length needs a two-byte varint prefix: the
    /// low 7 bits with the continuation bit set, then the remaining bits.
    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        assert!(len < (1 << 15));
        let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
        let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        data.extend(frame.clone().into_iter());
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move {
            framed.next().await
        }).unwrap();
        assert_eq!(recved.unwrap(), frame);
    }

    /// The prefix `0x81 0x81 0x01` would need three varint bytes, more than
    /// the decoder allows, so the first item must be `InvalidData`.
    #[test]
    fn packet_len_too_long() {
        let mut data = vec![0x81, 0x81, 0x1];
        data.extend((0..16513).map(|_| 0));
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move {
            framed.next().await.unwrap()
        });

        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::InvalidData)
        } else {
            panic!()
        }
    }

    /// Zero-length prefixes decode to empty frames, interleaved with
    /// non-empty ones.
    #[test]
    fn empty_frames() {
        let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(
            recved,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    /// `0x89` has its continuation bit set, but the input ends before the
    /// next prefix byte.
    #[test]
    fn unexpected_eof_in_len() {
        let data = vec![0x89];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// A prefix announcing 5 bytes with no payload at all.
    #[test]
    fn unexpected_eof_in_data() {
        let data = vec![5];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// A prefix announcing 5 bytes followed by only 3 payload bytes.
    #[test]
    fn unexpected_eof_in_data2() {
        let data = vec![5, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// Property test: frames sent through the `Sink` side over a local TCP
    /// connection are read back, one frame per `read`, through
    /// `RwStreamSink` on the other end.
    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            async_std::task::block_on(async move {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let listener_addr = listener.local_addr().unwrap();

                let expected_frames = frames.clone();
                let server = async_std::task::spawn(async move {
                    let socket = listener.accept().await.unwrap().0;
                    let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));

                    let mut buf = vec![0u8; 0];
                    for expected in expected_frames {
                        // Empty frames are not checked on the read side
                        // (presumably because a byte-oriented read cannot
                        // observe them — TODO confirm).
                        if expected.is_empty() {
                            continue;
                        }
                        // Grow the buffer so a whole frame fits in one read.
                        if buf.len() < expected.len() {
                            buf.resize(expected.len(), 0);
                        }
                        let n = connec.read(&mut buf).await.unwrap();
                        assert_eq!(&buf[..n], &expected[..]);
                    }
                });

                let client = async_std::task::spawn(async move {
                    let socket = TcpStream::connect(&listener_addr).await.unwrap();
                    let mut connec = LengthDelimited::new(socket);
                    for frame in frames {
                        connec.send(From::from(frame)).await.unwrap();
                    }
                });

                server.await;
                client.await;
            });

            TestResult::passed()
        }

        quickcheck(prop as fn(_) -> _)
    }
}