mqtt_codec_kit/common/
encodable.rs

1use std::{
2    convert::Infallible,
3    error::Error,
4    io::{self, Read, Write},
5    marker::Sized,
6    slice,
7};
8
9use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10
/// Serialization of a value into its MQTT wire representation.
///
/// Implementors write their bytes to any [`Write`] sink and can report in
/// advance how many bytes that representation occupies.
pub trait Encodable {
    /// Writes the wire representation of `self` into `writer`.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `writer`; implementations may also report
    /// their own errors.
    fn encode<Out: Write>(&self, writer: &mut Out) -> io::Result<()>;

    /// Number of bytes produced by [`Encodable::encode`] for `self`.
    fn encoded_length(&self) -> u32;
}
17
18impl<T: Encodable> Encodable for Option<T> {
19    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
20        if let Some(this) = self {
21            this.encode(writer)?
22        }
23        Ok(())
24    }
25
26    fn encoded_length(&self) -> u32 {
27        self.as_ref().map_or(0, |x| x.encoded_length())
28    }
29}
30
31impl<'a> Encodable for &'a str {
32    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
33        assert!(self.as_bytes().len() <= u16::MAX as usize);
34
35        writer
36            .write_u16::<BigEndian>(self.as_bytes().len() as u16)
37            .and_then(|_| writer.write_all(self.as_bytes()))
38    }
39
40    fn encoded_length(&self) -> u32 {
41        2 + self.as_bytes().len() as u32
42    }
43}
44
45impl<'a> Encodable for &'a [u8] {
46    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
47        writer.write_all(self)
48    }
49
50    fn encoded_length(&self) -> u32 {
51        self.len() as u32
52    }
53}
54
55impl Encodable for String {
56    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
57        (&self[..]).encode(writer)
58    }
59
60    fn encoded_length(&self) -> u32 {
61        (&self[..]).encoded_length()
62    }
63}
64
65impl Encodable for Vec<u8> {
66    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
67        (&self[..]).encode(writer)
68    }
69
70    fn encoded_length(&self) -> u32 {
71        (&self[..]).encoded_length()
72    }
73}
74
75impl Encodable for () {
76    fn encode<W: Write>(&self, _: &mut W) -> Result<(), io::Error> {
77        Ok(())
78    }
79
80    fn encoded_length(&self) -> u32 {
81        0
82    }
83}
84
/// Construction of a value from bytes laid out according to the MQTT specification.
///
/// `Cond` carries any extra context a decoder needs (for example an expected
/// length); types that need none use `()`.
pub trait Decodable: Sized {
    /// Error reported when decoding fails.
    type Error: Error;
    /// Extra decoding context accepted by [`Decodable::decode_with`].
    type Cond;

    /// Decodes a value from `reader` using the default decoding context.
    fn decode<In: Read>(reader: &mut In) -> Result<Self, Self::Error>
    where
        Self::Cond: Default,
    {
        let cond = <Self::Cond as Default>::default();
        Self::decode_with(reader, cond)
    }

    /// Decodes a value from `reader`, guided by the supplied context `cond`.
    fn decode_with<In: Read>(reader: &mut In, cond: Self::Cond) -> Result<Self, Self::Error>;
}
101
102impl Decodable for String {
103    type Error = io::Error;
104    type Cond = ();
105
106    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<String, io::Error> {
107        let VarBytes(buf) = VarBytes::decode(reader)?;
108
109        String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
110    }
111}
112
113impl Decodable for Vec<u8> {
114    type Error = io::Error;
115    type Cond = Option<u32>;
116
117    fn decode_with<R: Read>(reader: &mut R, length: Option<u32>) -> Result<Vec<u8>, io::Error> {
118        match length {
119            Some(length) => {
120                let mut buf = Vec::with_capacity(length as usize);
121                reader.take(length.into()).read_to_end(&mut buf)?;
122                Ok(buf)
123            }
124            None => {
125                let mut buf = Vec::new();
126                reader.read_to_end(&mut buf)?;
127                Ok(buf)
128            }
129        }
130    }
131}
132
133impl Decodable for () {
134    type Error = Infallible;
135    type Cond = ();
136
137    fn decode_with<R: Read>(_: &mut R, _: ()) -> Result<(), Self::Error> {
138        Ok(())
139    }
140}
141
/// Binary data that travels on the wire behind a big-endian `u16` length prefix.
///
/// The wrapped vector holds only the payload; the prefix is produced by its
/// `Encodable` impl and consumed by its `Decodable` impl.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VarBytes(pub Vec<u8>);
145
146impl Encodable for VarBytes {
147    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
148        assert!(self.0.len() <= u16::MAX as usize);
149        let len = self.0.len() as u16;
150        writer.write_u16::<BigEndian>(len)?;
151        writer.write_all(&self.0)?;
152        Ok(())
153    }
154
155    fn encoded_length(&self) -> u32 {
156        2 + self.0.len() as u32
157    }
158}
159
160impl Decodable for VarBytes {
161    type Error = io::Error;
162    type Cond = ();
163
164    fn decode_with<R: Read>(reader: &mut R, _: ()) -> Result<VarBytes, io::Error> {
165        let length = reader.read_u16::<BigEndian>()?;
166        let mut buf = Vec::with_capacity(length as usize);
167        reader.take(length.into()).read_to_end(&mut buf)?;
168        Ok(VarBytes(buf))
169    }
170}
171
/// MQTT variable byte integer: 7 value bits per byte, least significant group
/// first, with the high bit of each byte marking that another byte follows.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct VarInt(pub u32);
174
175impl Encodable for VarInt {
176    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
177        let mut value = self.0;
178        loop {
179            let mut byte = (value % 128) as u8;
180            value /= 128;
181            if value > 0 {
182                byte |= 128;
183            }
184            writer.write_u8(byte)?;
185            if value == 0 {
186                break;
187            }
188        }
189        Ok(())
190    }
191
192    fn encoded_length(&self) -> u32 {
193        if self.0 >= 2_097_152 {
194            4
195        } else if self.0 >= 16_384 {
196            3
197        } else if self.0 >= 128 {
198            2
199        } else {
200            1
201        }
202    }
203}
204
205impl Decodable for VarInt {
206    type Error = io::Error;
207    type Cond = ();
208
209    fn decode_with<R: Read>(reader: &mut R, _cond: Self::Cond) -> Result<Self, Self::Error> {
210        let mut byte = 0u8;
211        let mut var_int: u32 = 0;
212        let mut i: usize = 0;
213        loop {
214            reader.read_exact(slice::from_mut(&mut byte))?;
215            var_int |= (u32::from(byte) & 0x7F) << (7 * i);
216            if byte & 0x80 == 0 {
217                break;
218            } else if i < 3 {
219                i += 1;
220            } else {
221                return Err(io::Error::from(io::ErrorKind::InvalidData));
222            }
223        }
224        Ok(Self(var_int))
225    }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    #[test]
    fn varbyte_encode() {
        let test_var = vec![0, 1, 2, 3, 4, 5];
        let bytes = VarBytes(test_var);

        assert_eq!(bytes.encoded_length() as usize, 2 + 6);

        let mut buf = Vec::new();
        bytes.encode(&mut buf).unwrap();

        assert_eq!(&buf, &[0, 6, 0, 1, 2, 3, 4, 5]);

        let mut reader = Cursor::new(buf);
        let decoded = VarBytes::decode(&mut reader).unwrap();

        assert_eq!(decoded, bytes);
    }

    /// Round-trips VarInt values on both sides of every byte-count boundary and
    /// checks `encoded_length` against the bytes actually written.
    #[test]
    fn varint_round_trip_at_length_boundaries() {
        let cases: &[(u32, &[u8])] = &[
            (0, &[0x00]),
            (127, &[0x7F]),
            (128, &[0x80, 0x01]),
            (16_383, &[0xFF, 0x7F]),
            (16_384, &[0x80, 0x80, 0x01]),
            (2_097_151, &[0xFF, 0xFF, 0x7F]),
            (2_097_152, &[0x80, 0x80, 0x80, 0x01]),
            (268_435_455, &[0xFF, 0xFF, 0xFF, 0x7F]),
        ];
        for &(value, expected) in cases {
            let var_int = VarInt(value);
            let mut buf = Vec::new();
            var_int.encode(&mut buf).unwrap();
            assert_eq!(buf, expected, "encoding {}", value);
            assert_eq!(var_int.encoded_length() as usize, expected.len());

            let decoded = VarInt::decode(&mut Cursor::new(buf)).unwrap();
            assert_eq!(decoded, var_int);
        }
    }

    /// A fifth continuation byte is outside the format and must be rejected.
    #[test]
    fn varint_decode_rejects_fifth_byte() {
        let mut reader = Cursor::new(vec![0x80, 0x80, 0x80, 0x80, 0x01]);
        let err = VarInt::decode(&mut reader).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn string_round_trip() {
        let text = String::from("a/b");
        let mut buf = Vec::new();
        text.encode(&mut buf).unwrap();
        assert_eq!(buf, [0, 3, b'a', b'/', b'b']);
        assert_eq!(text.encoded_length() as usize, buf.len());

        let decoded = String::decode(&mut Cursor::new(buf)).unwrap();
        assert_eq!(decoded, text);
    }
}