mqtt_codec_kit/common/
encodable.rs1use std::{
2 convert::Infallible,
3 error::Error,
4 io::{self, Read, Write},
5 marker::Sized,
6 slice,
7};
8
9use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
10
/// A value that can be serialized onto a byte stream.
pub trait Encodable {
    /// Writes the encoded form of `self` to `writer`.
    fn encode<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: Write;

    /// Number of bytes that [`Encodable::encode`] produces for `self`.
    fn encoded_length(&self) -> u32;
}
17
18impl<T: Encodable> Encodable for Option<T> {
19 fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
20 if let Some(this) = self {
21 this.encode(writer)?
22 }
23 Ok(())
24 }
25
26 fn encoded_length(&self) -> u32 {
27 self.as_ref().map_or(0, |x| x.encoded_length())
28 }
29}
30
31impl<'a> Encodable for &'a str {
32 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
33 assert!(self.as_bytes().len() <= u16::MAX as usize);
34
35 writer
36 .write_u16::<BigEndian>(self.as_bytes().len() as u16)
37 .and_then(|_| writer.write_all(self.as_bytes()))
38 }
39
40 fn encoded_length(&self) -> u32 {
41 2 + self.as_bytes().len() as u32
42 }
43}
44
45impl<'a> Encodable for &'a [u8] {
46 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
47 writer.write_all(self)
48 }
49
50 fn encoded_length(&self) -> u32 {
51 self.len() as u32
52 }
53}
54
55impl Encodable for String {
56 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
57 (&self[..]).encode(writer)
58 }
59
60 fn encoded_length(&self) -> u32 {
61 (&self[..]).encoded_length()
62 }
63}
64
65impl Encodable for Vec<u8> {
66 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
67 (&self[..]).encode(writer)
68 }
69
70 fn encoded_length(&self) -> u32 {
71 (&self[..]).encoded_length()
72 }
73}
74
75impl Encodable for () {
76 fn encode<W: Write>(&self, _: &mut W) -> Result<(), io::Error> {
77 Ok(())
78 }
79
80 fn encoded_length(&self) -> u32 {
81 0
82 }
83}
84
/// A value that can be deserialized from a byte stream.
///
/// `Cond` is extra, caller-supplied context passed to [`Decodable::decode_with`]; the impls in
/// this module that need no context use `()`.
pub trait Decodable: Sized {
    /// Error type returned when decoding fails.
    type Error: Error;
    /// Caller-supplied decoding context.
    type Cond;

    /// Decodes a value, using `Cond::default()` as the context.
    fn decode<R>(reader: &mut R) -> Result<Self, Self::Error>
    where
        R: Read,
        Self::Cond: Default,
    {
        let cond = <Self::Cond as Default>::default();
        Self::decode_with(reader, cond)
    }

    /// Decodes a value from `reader` using the explicit context `cond`.
    fn decode_with<R: Read>(reader: &mut R, cond: Self::Cond) -> Result<Self, Self::Error>;
}
101
102impl Decodable for String {
103 type Error = io::Error;
104 type Cond = ();
105
106 fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<String, io::Error> {
107 let VarBytes(buf) = VarBytes::decode(reader)?;
108
109 String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
110 }
111}
112
113impl Decodable for Vec<u8> {
114 type Error = io::Error;
115 type Cond = Option<u32>;
116
117 fn decode_with<R: Read>(reader: &mut R, length: Option<u32>) -> Result<Vec<u8>, io::Error> {
118 match length {
119 Some(length) => {
120 let mut buf = Vec::with_capacity(length as usize);
121 reader.take(length.into()).read_to_end(&mut buf)?;
122 Ok(buf)
123 }
124 None => {
125 let mut buf = Vec::new();
126 reader.read_to_end(&mut buf)?;
127 Ok(buf)
128 }
129 }
130 }
131}
132
133impl Decodable for () {
134 type Error = Infallible;
135 type Cond = ();
136
137 fn decode_with<R: Read>(_: &mut R, _: ()) -> Result<(), Self::Error> {
138 Ok(())
139 }
140}
141
142#[derive(Debug, Eq, PartialEq, Clone)]
144pub struct VarBytes(pub Vec<u8>);
145
146impl Encodable for VarBytes {
147 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
148 assert!(self.0.len() <= u16::MAX as usize);
149 let len = self.0.len() as u16;
150 writer.write_u16::<BigEndian>(len)?;
151 writer.write_all(&self.0)?;
152 Ok(())
153 }
154
155 fn encoded_length(&self) -> u32 {
156 2 + self.0.len() as u32
157 }
158}
159
160impl Decodable for VarBytes {
161 type Error = io::Error;
162 type Cond = ();
163
164 fn decode_with<R: Read>(reader: &mut R, _: ()) -> Result<VarBytes, io::Error> {
165 let length = reader.read_u16::<BigEndian>()?;
166 let mut buf = Vec::with_capacity(length as usize);
167 reader.take(length.into()).read_to_end(&mut buf)?;
168 Ok(VarBytes(buf))
169 }
170}
171
172#[derive(Debug, Clone, PartialEq, Eq, Default)]
173pub struct VarInt(pub u32);
174
175impl Encodable for VarInt {
176 fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
177 let mut value = self.0;
178 loop {
179 let mut byte = (value % 128) as u8;
180 value /= 128;
181 if value > 0 {
182 byte |= 128;
183 }
184 writer.write_u8(byte)?;
185 if value == 0 {
186 break;
187 }
188 }
189 Ok(())
190 }
191
192 fn encoded_length(&self) -> u32 {
193 if self.0 >= 2_097_152 {
194 4
195 } else if self.0 >= 16_384 {
196 3
197 } else if self.0 >= 128 {
198 2
199 } else {
200 1
201 }
202 }
203}
204
205impl Decodable for VarInt {
206 type Error = io::Error;
207 type Cond = ();
208
209 fn decode_with<R: Read>(reader: &mut R, _cond: Self::Cond) -> Result<Self, Self::Error> {
210 let mut byte = 0u8;
211 let mut var_int: u32 = 0;
212 let mut i: usize = 0;
213 loop {
214 reader.read_exact(slice::from_mut(&mut byte))?;
215 var_int |= (u32::from(byte) & 0x7F) << (7 * i);
216 if byte & 0x80 == 0 {
217 break;
218 } else if i < 3 {
219 i += 1;
220 } else {
221 return Err(io::Error::from(io::ErrorKind::InvalidData));
222 }
223 }
224 Ok(Self(var_int))
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    #[test]
    fn varbyte_encode() {
        let test_var = vec![0, 1, 2, 3, 4, 5];
        let bytes = VarBytes(test_var);

        assert_eq!(bytes.encoded_length() as usize, 2 + 6);

        let mut buf = Vec::new();
        bytes.encode(&mut buf).unwrap();

        assert_eq!(&buf, &[0, 6, 0, 1, 2, 3, 4, 5]);

        let mut reader = Cursor::new(buf);
        let decoded = VarBytes::decode(&mut reader).unwrap();

        assert_eq!(decoded, bytes);
    }

    /// `&str` encoding is a big-endian u16 length prefix plus UTF-8 bytes, and `String`
    /// decoding reverses it.
    #[test]
    fn str_round_trip() {
        let mut buf = Vec::new();
        "abc".encode(&mut buf).unwrap();
        assert_eq!(buf, vec![0u8, 3, b'a', b'b', b'c']);
        assert_eq!("abc".encoded_length(), 5);

        let decoded = String::decode(&mut Cursor::new(buf)).unwrap();
        assert_eq!(decoded, "abc");
    }

    /// Each encoded-length boundary encodes to the expected bytes and decodes back.
    #[test]
    fn varint_boundaries_round_trip() {
        let cases: Vec<(u32, Vec<u8>)> = vec![
            (0, vec![0x00]),
            (127, vec![0x7F]),
            (128, vec![0x80, 0x01]),
            (16_383, vec![0xFF, 0x7F]),
            (16_384, vec![0x80, 0x80, 0x01]),
            (2_097_151, vec![0xFF, 0xFF, 0x7F]),
            (2_097_152, vec![0x80, 0x80, 0x80, 0x01]),
            (268_435_455, vec![0xFF, 0xFF, 0xFF, 0x7F]),
        ];
        for (value, expected) in cases {
            let var = VarInt(value);
            let mut buf = Vec::new();
            var.encode(&mut buf).unwrap();
            assert_eq!(buf, expected, "encoding of {}", value);
            assert_eq!(var.encoded_length() as usize, expected.len());
            assert_eq!(VarInt::decode(&mut Cursor::new(buf)).unwrap(), var);
        }
    }

    /// A fourth byte that still has the continuation bit set is rejected.
    #[test]
    fn varint_decode_rejects_more_than_four_bytes() {
        let mut reader = Cursor::new(vec![0x80u8, 0x80, 0x80, 0x80, 0x01]);
        let err = VarInt::decode(&mut reader).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}
251}