1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use super::{Decoder, Encoder};
use bytes::{Buf, Bytes, BytesMut};
use std::{convert::TryFrom, marker::PhantomData};

/// Codec that frames each item with a length header of type `L`.
///
/// The struct holds no data: `L` is carried only through [`PhantomData`] and
/// determines how many bytes the header occupies.
pub struct LengthCodec<L>(PhantomData<L>);
// NOTE(review): `impl_phantom!` is defined outside this file view; presumably it
// supplies the constructor (`LengthCodec::new()` is used by the tests) — confirm
// against the macro definition.
impl_phantom!(LengthCodec<L>);

/// Error raised when a length cannot be represented in the target integer type.
///
/// It is produced when encoding a payload whose length does not fit in the
/// header type, and when decoding a header whose value does not fit in `usize`.
#[derive(Debug, thiserror::Error)]
#[error("length overflow")]
pub struct OverflowError;

impl<L> LengthCodec<L> {
    /// Number of bytes occupied by the length header, i.e. the in-memory size of `L`.
    const HEADER_LEN: usize = core::mem::size_of::<L>();
}

/// Integer type that can serve as the length header of a [`LengthCodec`].
pub trait Length {
    /// Converts `x` to `Self` and appends its byte representation to `dst`.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowError`] when `x` cannot be represented as `Self`.
    fn encode(x: usize, dst: &mut BytesMut) -> Result<(), OverflowError>;

    /// Reads a length header from the front of `src` and returns it as `usize`.
    ///
    /// The implementations generated in this file index `src` directly, so the
    /// caller must ensure `src` holds at least a full header.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowError`] when the decoded value does not fit in `usize`.
    fn start_decode(src: &[u8]) -> Result<usize, OverflowError>;
}

/// Implements [`Length`] for integer types, using a big-endian header.
///
/// Each `$x => $y` pair names the integer type and the width of its byte
/// representation; `$y` must equal the array length that `$x::from_be_bytes`
/// expects.
macro_rules! impl_length {
    ($($x:ty => $y:expr),+ $(,)?) => {
        $(
        impl Length for $x {
            /// Appends `x` to `dst` as a big-endian value of this integer type.
            ///
            /// Fails with `OverflowError` when `x` does not fit in the type.
            fn encode(x: usize, dst: &mut BytesMut) -> Result<(), OverflowError> {
                match Self::try_from(x) {
                    Ok(len) => {
                        dst.extend_from_slice(&len.to_be_bytes());
                        Ok(())
                    }
                    Err(_) => Err(OverflowError),
                }
            }

            /// Reads the big-endian header at the front of `src` as a `usize`.
            ///
            /// # Panics
            ///
            /// Panics when `src` is shorter than the header, because the
            /// header bytes are taken by slice indexing.
            fn start_decode(src: &[u8]) -> Result<usize, OverflowError> {
                let mut header = [0u8; $y];
                header.copy_from_slice(&src[..$y]);
                let value = Self::from_be_bytes(header);
                usize::try_from(value).map_err(|_| OverflowError)
            }
        }
        )+
    }
}

impl_length!(u8 => 1, u16 => 2, u32 => 4, u64 => 8);

/// Encodes each item as a length header followed by the item's bytes.
impl<L: Length> Encoder for LengthCodec<L> {
    type Error = OverflowError;
    type Item = Bytes;

    /// Appends the length of `src` (as an `L` header) and then `src` itself to `dst`.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowError`] when the payload length does not fit in `L`;
    /// in that case no bytes are appended to `dst`.
    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let payload_len = src.len();

        // Make room for header and payload up front so the two writes below
        // do not each have to grow the buffer.
        dst.reserve(Self::HEADER_LEN + payload_len);

        L::encode(payload_len, dst)?;
        dst.extend_from_slice(&src[..]);
        Ok(())
    }
}

/// Decodes frames made of a length header followed by that many payload bytes.
impl<L: Length> Decoder for LengthCodec<L> {
    type Error = OverflowError;
    type Item = Bytes;

    /// Tries to extract one complete frame from the front of `src`.
    ///
    /// Returns `Ok(None)` and leaves `src` untouched while the header or the
    /// payload is still incomplete. Otherwise the header and payload are
    /// removed from `src` and the payload alone is returned.
    ///
    /// # Errors
    ///
    /// Returns [`OverflowError`] when the header value does not fit in `usize`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let header_len = Self::HEADER_LEN;

        // Not even a full header yet.
        if src.len() < header_len {
            return Ok(None);
        }

        let payload_len = L::start_decode(src)?;

        // The subtraction cannot underflow: the guard above ensures
        // `src.len() >= header_len`.
        if src.len() - header_len < payload_len {
            return Ok(None);
        }

        // Drop the header that was already read, then hand out the payload.
        src.advance(header_len);
        Ok(Some(src.split_to(payload_len).freeze()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    mod decode {
        use super::*;

        /// A complete frame yields its payload without the header, and bytes
        /// belonging to the next frame stay in the buffer.
        #[test]
        fn it_returns_bytes_without_length_header() {
            let mut codec = LengthCodec::<u64>::new();

            // 8-byte big-endian header (= 3), three payload bytes, one extra byte.
            let mut src = BytesMut::from(&[0, 0, 0, 0, 0, 0, 0, 3u8, 1, 2, 3, 4][..]);
            let item = codec.decode(&mut src).unwrap();

            assert_eq!(item, Some(Bytes::from(&[1u8, 2, 3][..])));
            assert_eq!(&src[..], &[4u8][..]);
        }

        /// Fewer bytes than a full header: nothing is decoded or consumed.
        #[test]
        fn it_returns_none_when_header_is_incomplete() {
            let mut codec = LengthCodec::<u64>::new();

            let mut src = BytesMut::from(&[0u8, 0, 0][..]);
            let item = codec.decode(&mut src).unwrap();

            assert_eq!(item, None);
            assert_eq!(src.len(), 3);
        }

        /// Full header but a short payload: nothing is decoded or consumed.
        #[test]
        fn it_returns_none_when_payload_is_incomplete() {
            let mut codec = LengthCodec::<u64>::new();

            let mut src = BytesMut::from(&[0, 0, 0, 0, 0, 0, 0, 3u8, 1, 2][..]);
            let item = codec.decode(&mut src).unwrap();

            assert_eq!(item, None);
            assert_eq!(src.len(), 10);
        }
    }

    mod encode {
        use super::*;

        /// The payload is written after a big-endian length header.
        #[test]
        fn it_prefixes_payload_with_length_header() {
            let mut codec = LengthCodec::<u16>::new();

            let mut dst = BytesMut::new();
            codec
                .encode(Bytes::from(&[1u8, 2, 3][..]), &mut dst)
                .unwrap();

            assert_eq!(&dst[..], &[0u8, 3, 1, 2, 3][..]);
        }

        /// A payload longer than the header type can express is rejected.
        #[test]
        fn it_fails_when_length_does_not_fit_header() {
            let mut codec = LengthCodec::<u8>::new();

            let mut dst = BytesMut::new();
            let result = codec.encode(Bytes::from(vec![0u8; 256]), &mut dst);

            assert!(result.is_err());
        }
    }
}