use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use tokio_codec::FramedWrite;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
/// Length-delimited framing adapter around an underlying stream `S`.
///
/// Outgoing items are delegated to a `FramedWrite` using the `UviBytes` encoder, while
/// incoming frames are decoded by hand in the `Stream` implementation: a length prefix of
/// at most two bytes followed by that many bytes of payload, converted into `I`.
pub struct LengthDelimited<I, S> {
/// Write half wrapping the underlying stream; reads go through `inner.get_mut()`.
inner: FramedWrite<S, UviBytes>,
/// Scratch buffer: holds the length-prefix bytes while in `State::ReadingLength` and the
/// frame payload while in `State::ReadingData`. Expected to be non-empty whenever `poll`
/// runs (checked with a `debug_assert!`).
internal_buffer: SmallVec<[u8; 64]>,
/// Number of bytes at the start of `internal_buffer` that have already been filled.
internal_buffer_pos: usize,
/// Which part of a frame is currently being read.
state: State,
/// Ties the otherwise unused item type `I` to the struct.
marker: PhantomData<I>,
}
/// Progress of the frame decoder between two calls to `poll`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
/// Reading the length prefix, one byte per read, into `internal_buffer`.
ReadingLength,
/// The prefix has been decoded; reading `frame_len` bytes of payload into `internal_buffer`.
ReadingData { frame_len: u16 },
}
impl<I, S> LengthDelimited<I, S>
where
    S: AsyncWrite,
{
    /// Wraps `inner` in a new `LengthDelimited`.
    ///
    /// The value starts in the length-reading state with a one-byte scratch buffer ready
    /// to receive the first byte of a length prefix.
    ///
    /// The encoder is capped at 16383 bytes per frame. The reading half of this type
    /// rejects any length prefix longer than two bytes, and two bytes carry only 14 bits
    /// of length, so a larger limit would let the sink produce frames that this same type
    /// refuses to read back.
    pub fn new(inner: S) -> LengthDelimited<I, S> {
        // Two length-prefix bytes => 14 payload bits => 0x3fff (16383) bytes at most.
        let max_frame_len = usize::from(u16::MAX >> 2);

        let mut encoder = UviBytes::default();
        encoder.set_max_len(max_frame_len);

        LengthDelimited {
            inner: FramedWrite::new(inner, encoder),
            internal_buffer: {
                // One zeroed byte so the first read has somewhere to land.
                let mut v = SmallVec::new();
                v.push(0);
                v
            },
            internal_buffer_pos: 0,
            state: State::ReadingLength,
            marker: PhantomData,
        }
    }

    /// Destroys the `LengthDelimited` and returns the underlying stream.
    ///
    /// # Panics
    ///
    /// Panics if a frame is only partially read, i.e. if the decoder is not in the
    /// length-reading state or if some bytes of a length prefix are already buffered.
    /// Returning the stream at that point would lose the buffered bytes.
    #[inline]
    pub fn into_inner(self) -> S {
        assert_eq!(self.state, State::ReadingLength);
        assert_eq!(self.internal_buffer_pos, 0);
        self.inner.into_inner()
    }
}
impl<I, S> LengthDelimited<I, S>
where
    S: AsyncRead,
{
    /// Performs one read from the underlying stream into the unfilled tail of
    /// `internal_buffer`.
    ///
    /// Returns the number of bytes read, `Async::NotReady` if the stream reported
    /// `WouldBlock`, or the I/O error otherwise. `internal_buffer_pos` is left untouched.
    fn read_into_buffer(&mut self) -> Poll<usize, IoError> {
        let unfilled = &mut self.internal_buffer[self.internal_buffer_pos..];
        match self.inner.get_mut().read(unfilled) {
            Ok(count) => Ok(Async::Ready(count)),
            Err(ref err) if err.kind() == IoErrorKind::WouldBlock => Ok(Async::NotReady),
            Err(err) => Err(err),
        }
    }

    /// Puts the decoder back into the length-reading state with a fresh one-byte buffer.
    fn expect_length_prefix(&mut self) {
        self.state = State::ReadingLength;
        self.internal_buffer.clear();
        self.internal_buffer.push(0);
        self.internal_buffer_pos = 0;
    }
}

/// Yields one `I` per decoded frame.
///
/// The stream ends cleanly when the underlying stream reports EOF exactly on a frame
/// boundary. EOF anywhere else yields a `BrokenPipe` error, and a length prefix longer
/// than two bytes yields `InvalidData`.
impl<I, S> Stream for LengthDelimited<I, S>
where
    S: AsyncRead,
    I: for<'r> From<&'r [u8]>,
{
    type Item = I;
    type Error = IoError;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        loop {
            debug_assert!(!self.internal_buffer.is_empty());
            debug_assert!(self.internal_buffer_pos < self.internal_buffer.len());

            // Both states begin with the same read; only the interpretation differs.
            let received = match self.read_into_buffer()? {
                Async::Ready(count) => count,
                Async::NotReady => return Ok(Async::NotReady),
            };

            match self.state {
                State::ReadingLength => {
                    if received == 0 {
                        // EOF is only clean if no prefix byte has been consumed yet.
                        return if self.internal_buffer_pos == 0 {
                            Ok(Async::Ready(None))
                        } else {
                            Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"))
                        };
                    }

                    debug_assert_eq!(received, 1);
                    self.internal_buffer_pos += received;
                    debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos);

                    let newest = self.internal_buffer.last().cloned().unwrap_or(0);
                    if newest & 0x80 != 0 {
                        // Continuation bit set: more prefix bytes follow, up to two in total.
                        if self.internal_buffer_pos >= 2 {
                            return Err(IoError::new(
                                IoErrorKind::InvalidData,
                                "frame length too long",
                            ));
                        }
                        self.internal_buffer.push(0);
                        continue;
                    }

                    // The prefix is complete.
                    let frame_len = decode_length_prefix(&self.internal_buffer);
                    if frame_len == 0 {
                        // Empty frame: nothing to read, emit it right away.
                        self.expect_length_prefix();
                        let empty: &[u8] = &[];
                        return Ok(Async::Ready(Some(I::from(empty))));
                    }

                    // Size the buffer to exactly one payload and start filling it.
                    self.state = State::ReadingData { frame_len };
                    self.internal_buffer.clear();
                    self.internal_buffer.reserve(usize::from(frame_len));
                    self.internal_buffer.extend((0..frame_len).map(|_| 0));
                    self.internal_buffer_pos = 0;
                }
                State::ReadingData { frame_len } => {
                    if received == 0 {
                        return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
                    }

                    self.internal_buffer_pos += received;
                    if self.internal_buffer_pos < usize::from(frame_len) {
                        // Payload still incomplete; keep reading.
                        continue;
                    }

                    let frame = I::from(&self.internal_buffer[..]);
                    self.expect_length_prefix();
                    return Ok(Async::Ready(Some(frame)));
                }
            }
        }
    }
}
impl<I, S> Sink for LengthDelimited<I, S>
where
S: AsyncWrite
{
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
self.inner.start_send(item)
}
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
}
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.close()
}
}
/// Decodes a length prefix of at most two bytes into a `u16`.
///
/// The first byte of `buf` holds the least significant 7 bits. The top bit of every
/// byte is ignored. An empty slice decodes to `0`.
fn decode_length_prefix(buf: &[u8]) -> u16 {
    debug_assert!(buf.len() <= 2);

    // Walk from the most significant group down, making room for 7 bits each time.
    // After the shift the low 7 bits are zero, so OR-ing the masked byte in is the
    // same as adding it.
    buf.iter()
        .rev()
        .fold(0u16, |acc, &raw| (acc << 7) | u16::from(raw & 0x7f))
}
// Tests for the reading half: each one feeds a fixed byte sequence through a `Cursor`
// and checks the frames (or the error kind) produced by the stream.
#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::Cursor;
use std::io::ErrorKind;
// A single frame: one length byte (6) followed by six payload bytes.
#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}
// Two back-to-back frames of lengths 6 and 3.
#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
// A 5000-byte frame, whose length needs a two-byte prefix: low 7 bits with the
// continuation bit set, then the remaining high bits.
#[test]
fn two_bytes_long_packet() {
let len = 5000u16;
assert!(len < (1 << 15));
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|_| ())
.wait()
.unwrap();
assert_eq!(recved.unwrap(), frame);
}
// A three-byte length prefix (the first two bytes both have the continuation bit set)
// must be rejected with `InvalidData`.
#[test]
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|(err, _)| err)
.wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
_ => panic!(),
}
}
// Zero-length prefixes produce empty frames, interleaved with non-empty ones.
#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(
recved,
vec![
vec![],
vec![],
vec![9, 8, 7, 6, 5, 4],
vec![],
vec![9, 8, 7],
]
);
}
// EOF after a prefix byte whose continuation bit is set: the prefix is incomplete.
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
}
}
// EOF immediately after a complete prefix announcing five bytes of payload.
#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
}
}
// EOF part-way through the payload: five bytes announced, only three present.
#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
}
}
}