async_codec_lite/codec/
length.rs

1use super::{Decoder, Encoder};
2use bytes::{Buf, Bytes, BytesMut};
3use std::{convert::TryFrom, marker::PhantomData};
4
5pub struct LengthCodec<L>(PhantomData<L>);
6impl_phantom!(LengthCodec<L>);
7
8#[derive(Debug, thiserror::Error)]
9#[error("length overflow")]
10pub struct OverflowError;
11
12impl<L> LengthCodec<L> {
13    const HEADER_LEN: usize = std::mem::size_of::<L>();
14}
15
16pub trait Length {
17    fn encode(x: usize, dst: &mut BytesMut) -> Result<(), OverflowError>;
18    fn start_decode(src: &[u8]) -> Result<usize, OverflowError>;
19}
20
21macro_rules! impl_length {
22    ($($x:ty => $y:expr),+ $(,)?) => {
23        $(
24        impl Length for $x {
25            fn encode(x: usize, dst: &mut BytesMut) -> Result<(), OverflowError> {
26                let this = Self::try_from(x).map_err(|_| OverflowError)?;
27                dst.extend_from_slice(&Self::to_be_bytes(this));
28                Ok(())
29            }
30
31            fn start_decode(src: &[u8]) -> Result<usize, OverflowError> {
32                let mut len_bytes = [0u8; $y];
33                len_bytes.copy_from_slice(&src[..$y]);
34                usize::try_from(Self::from_be_bytes(len_bytes)).map_err(|_| OverflowError)
35            }
36        }
37        )+
38    }
39}
40
41impl_length!(u8 => 1, u16 => 2, u32 => 4, u64 => 8);
42
43impl<L: Length> Encoder for LengthCodec<L> {
44    type Error = OverflowError;
45    type Item = Bytes;
46
47    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
48        dst.reserve(Self::HEADER_LEN + src.len());
49        L::encode(src.len(), dst)?;
50        dst.extend_from_slice(&src);
51        Ok(())
52    }
53}
54
55impl<L: Length> Decoder for LengthCodec<L> {
56    type Error = OverflowError;
57    type Item = Bytes;
58
59    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
60        Ok(if src.len() < std::mem::size_of::<L>() {
61            None
62        } else {
63            let len = L::start_decode(src)?;
64            if src.len() - Self::HEADER_LEN >= len {
65                // Skip the length header we already read.
66                src.advance(Self::HEADER_LEN);
67                Some(src.split_to(len).freeze())
68            } else {
69                None
70            }
71        })
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    mod decode {
        use super::*;
        use bytes::BufMut;

        #[test]
        fn it_returns_bytes_without_length_header() {
            let mut codec = LengthCodec::<u64>::new();

            let mut src = BytesMut::new();
            src.put(&[0, 0, 0, 0, 0, 0, 0, 3u8, 1, 2, 3, 4][..]);
            let item = codec.decode(&mut src).unwrap();

            assert_eq!(item, Some(Bytes::from(&[1u8, 2, 3][..])));
            // The trailing byte belongs to the next frame and must stay buffered.
            assert_eq!(&src[..], &[4u8][..]);
        }

        #[test]
        fn it_waits_for_a_complete_header() {
            let mut codec = LengthCodec::<u32>::new();

            let mut src = BytesMut::from(&[0u8, 0, 0][..]);

            assert_eq!(codec.decode(&mut src).unwrap(), None);
            assert_eq!(src.len(), 3);
        }

        #[test]
        fn it_waits_for_a_complete_payload() {
            let mut codec = LengthCodec::<u16>::new();

            let mut src = BytesMut::from(&[0u8, 4, 1, 2][..]);

            assert_eq!(codec.decode(&mut src).unwrap(), None);
            // Nothing, not even the header, is consumed until the frame is complete.
            assert_eq!(src.len(), 4);
        }
    }

    mod encode {
        use super::*;

        #[test]
        fn it_prefixes_payload_with_big_endian_length() {
            let mut codec = LengthCodec::<u16>::new();
            let mut dst = BytesMut::new();

            codec.encode(Bytes::from(&[7u8, 8, 9][..]), &mut dst).unwrap();

            assert_eq!(&dst[..], &[0u8, 3, 7, 8, 9][..]);
        }

        #[test]
        fn it_rejects_payloads_longer_than_the_header_can_express() {
            let mut codec = LengthCodec::<u8>::new();
            let mut dst = BytesMut::new();

            assert!(codec.encode(Bytes::from(vec![0u8; 256]), &mut dst).is_err());
            assert!(dst.is_empty());
        }
    }
}