1use core::result::Result;
2
3use crate::{
4 codec::{self, Decodable, Encodable},
5 error::{DecodeError, EncodeError},
6 status::Status,
7};
8
9mod packet_type;
10mod packet_flags;
11
12pub use self::{
13 packet_type::PacketType,
14 packet_flags::{
15 PacketFlags,
16 PublishFlags,
17 },
18};
19
20#[derive(Copy, Clone, PartialEq, Eq, Debug)]
21pub struct FixedHeader {
22 r#type: PacketType,
23 flags: PacketFlags,
24 len: u32,
25}
26
27impl FixedHeader {
28 pub fn new(r#type: PacketType, flags: PacketFlags, len: u32) -> Self {
29 FixedHeader {
30 r#type,
31 flags,
32 len
33 }
34 }
35
36 pub fn r#type(&self) -> PacketType {
37 self.r#type
38 }
39
40 pub fn flags(&self) -> PacketFlags {
41 self.flags
42 }
43
44 pub fn len(&self) -> u32 {
45 self.len
46 }
47}
48
49impl<'buf> Decodable<'buf> for FixedHeader {
50 fn decode(bytes: &'buf [u8]) -> Result<Status<(usize, Self)>, DecodeError> {
51 if bytes.len() < 2 {
53 return Ok(Status::Partial(2 - bytes.len()));
54 }
55
56 let (r#type, flags) = parse_packet_type(bytes[0])?;
57
58 let offset = 1;
59
60 let (offset, len) = read!(parse_remaining_length, bytes, offset);
61
62 Ok(Status::Complete((offset, Self {
63 r#type,
64 flags,
65 len
66 })))
67 }
68}
69
70impl Encodable for FixedHeader {
71 fn encoded_len(&self) -> usize {
72 let mut buf = [0u8; 4];
73 let u = encode_remaining_length(self.len, &mut buf);
74 1 + u
75 }
76
77 fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
78 let offset = 0;
79 let offset = {
80 let o = codec::values::encode_u8(encode_packet_type(self.r#type, self.flags), &mut bytes[offset..])?;
81 offset + o
82 };
83 let offset = {
84 let mut remaining_length = [0u8; 4];
85 let o = encode_remaining_length(self.len, &mut remaining_length);
86 (&mut bytes[offset..offset+o]).copy_from_slice(&remaining_length[..o]);
87 offset + o
88 };
89 Ok(offset)
90 }
91}
92
93fn parse_remaining_length(bytes: &[u8]) -> Result<Status<(usize, u32)>, DecodeError> {
94 let mut multiplier = 1;
95 let mut value = 0u32;
96 let mut index = 0;
97
98 loop {
99 if multiplier > 128 * 128 * 128 {
100 return Err(DecodeError::RemainingLength);
101 }
102
103 if index >= bytes.len() {
104 return Ok(Status::Partial(1));
105 }
106
107 let byte = bytes[index];
108 index += 1;
109
110 value += (byte & 0b01111111) as u32 * multiplier;
111
112 multiplier *= 128;
113
114 if byte & 128 == 0 {
115 return Ok(Status::Complete((index, value)));
116 }
117 }
118}
119
/// Encodes `len` as an MQTT variable-length "Remaining Length" field into `buf`
/// and returns how many bytes of `buf` were used (1 to 4).
///
/// Each output byte carries 7 value bits, least-significant group first. The
/// high bit is set when more bytes follow.
///
/// # Panics
/// Panics if `len` exceeds 268_435_455 (2^28 - 1), the largest value that fits
/// in four bytes. This used to surface as an opaque index-out-of-bounds panic.
fn encode_remaining_length(mut len: u32, buf: &mut [u8; 4]) -> usize {
    const MAX_REMAINING_LENGTH: u32 = 268_435_455;
    assert!(
        len <= MAX_REMAINING_LENGTH,
        "remaining length {} exceeds the maximum encodable value {}",
        len,
        MAX_REMAINING_LENGTH
    );

    let mut written = 0;
    loop {
        // Take the low 7 bits before narrowing, so no precedence subtlety.
        let mut byte = (len % 128) as u8;
        len /= 128;
        if len > 0 {
            // Continuation bit: another length byte follows.
            byte |= 0x80;
        }
        buf[written] = byte;
        written += 1;

        if len == 0 {
            return written;
        }
    }
}
136
137fn parse_packet_type(inp: u8) -> Result<(PacketType, PacketFlags), DecodeError> {
138 let packet_type = match (inp & 0xF0) >> 4 {
140 1 => PacketType::Connect,
141 2 => PacketType::Connack,
142 3 => PacketType::Publish,
143 4 => PacketType::Puback,
144 5 => PacketType::Pubrec,
145 6 => PacketType::Pubrel,
146 7 => PacketType::Pubcomp,
147 8 => PacketType::Subscribe,
148 9 => PacketType::Suback,
149 10 => PacketType::Unsubscribe,
150 11 => PacketType::Unsuback,
151 12 => PacketType::Pingreq,
152 13 => PacketType::Pingresp,
153 14 => PacketType::Disconnect,
154 _ => return Err(DecodeError::PacketType),
155 };
156
157 let flags = PacketFlags(inp & 0xF);
159
160 validate_flag(packet_type, flags)
161}
162
163fn encode_packet_type(r#type: PacketType, flags: PacketFlags) -> u8 {
164 let packet_type: u8 = match r#type {
165 PacketType::Connect => 1,
166 PacketType::Connack => 2,
167 PacketType::Publish => 3,
168 PacketType::Puback => 4,
169 PacketType::Pubrec => 5,
170 PacketType::Pubrel => 6,
171 PacketType::Pubcomp => 7,
172 PacketType::Subscribe => 8,
173 PacketType::Suback => 9,
174 PacketType::Unsubscribe => 10,
175 PacketType::Unsuback => 11,
176 PacketType::Pingreq => 12,
177 PacketType::Pingresp => 13,
178 PacketType::Disconnect => 14,
179 };
180
181 (packet_type << 4) | flags.0
182}
183
184fn validate_flag(packet_type: PacketType, flags: PacketFlags) -> Result<(PacketType, PacketFlags), DecodeError> {
185 const ZERO_TYPES: &[PacketType] = &[
187 PacketType::Connect,
188 PacketType::Connack,
189 PacketType::Puback,
190 PacketType::Pubrec,
191 PacketType::Pubcomp,
192 PacketType::Suback,
193 PacketType::Unsuback,
194 PacketType::Pingreq,
195 PacketType::Pingresp,
196 PacketType::Disconnect,
197 ];
198 const ONE_TYPES: &[PacketType] = &[
200 PacketType::Pubrel,
201 PacketType::Subscribe,
202 PacketType::Unsubscribe,
203 ];
204
205 validate_flag_val(packet_type, flags, ZERO_TYPES, PacketFlags(0b0000))
206 .and_then(|_| validate_flag_val(packet_type, flags, ONE_TYPES, PacketFlags(0b0010)))
207}
208
209fn validate_flag_val(
210 packet_type: PacketType,
211 flags: PacketFlags,
212 types: &[PacketType],
213 expected_flags: PacketFlags,
214) -> Result<(PacketType, PacketFlags), DecodeError> {
215 if let Some(_) = types.iter().find(|&&v| v == packet_type) {
216 if flags != expected_flags {
217 return Err(DecodeError::PacketFlag);
218 }
219 }
220
221 Ok((packet_type, flags))
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;

    #[test]
    fn packet_type() {
        let inputs: [(u8, PacketType); 14] = [
            (1 << 4 | 0b0000, PacketType::Connect),
            (2 << 4 | 0b0000, PacketType::Connack),
            (3 << 4 | 0b0000, PacketType::Publish),
            (4 << 4 | 0b0000, PacketType::Puback),
            (5 << 4 | 0b0000, PacketType::Pubrec),
            (6 << 4 | 0b0010, PacketType::Pubrel),
            (7 << 4 | 0b0000, PacketType::Pubcomp),
            (8 << 4 | 0b0010, PacketType::Subscribe),
            (9 << 4 | 0b0000, PacketType::Suback),
            (10 << 4 | 0b0010, PacketType::Unsubscribe),
            (11 << 4 | 0b0000, PacketType::Unsuback),
            (12 << 4 | 0b0000, PacketType::Pingreq),
            (13 << 4 | 0b0000, PacketType::Pingresp),
            (14 << 4 | 0b0000, PacketType::Disconnect),
        ];

        for &(byte, expected_type) in inputs.iter() {
            let expected_flag = PacketFlags(byte & 0xF);
            let (packet_type, flag) = parse_packet_type(byte).unwrap();
            assert_eq!(packet_type, expected_type);
            assert_eq!(flag, expected_flag);
        }
    }

    #[test]
    fn bad_packet_type() {
        let result = parse_packet_type(15 << 4);
        assert_eq!(result, Err(DecodeError::PacketType));
    }

    #[test]
    fn bad_zero_flags() {
        let inputs: [(u8, PacketType); 10] = [
            (1 << 4 | 1, PacketType::Connect),
            (2 << 4 | 1, PacketType::Connack),
            (4 << 4 | 1, PacketType::Puback),
            (5 << 4 | 1, PacketType::Pubrec),
            (7 << 4 | 1, PacketType::Pubcomp),
            (9 << 4 | 1, PacketType::Suback),
            (11 << 4 | 1, PacketType::Unsuback),
            (12 << 4 | 1, PacketType::Pingreq),
            (13 << 4 | 1, PacketType::Pingresp),
            (14 << 4 | 1, PacketType::Disconnect),
        ];
        for &(byte, packet_type) in inputs.iter() {
            let result = parse_packet_type(byte);
            assert_eq!(result, Err(DecodeError::PacketFlag), "{:?}", packet_type);
        }
    }

    #[test]
    fn bad_one_flags() {
        let inputs: [(u8, PacketType); 3] = [
            (6 << 4, PacketType::Pubrel),
            (8 << 4, PacketType::Subscribe),
            (10 << 4, PacketType::Unsubscribe),
        ];
        for &(byte, packet_type) in inputs.iter() {
            let result = parse_packet_type(byte);
            assert_eq!(result, Err(DecodeError::PacketFlag), "{:?}", packet_type);
        }
    }

    #[test]
    fn publish_flags() {
        // Every flag nibble, including 0b1111, is accepted for Publish.
        for i in 0..=15u8 {
            let input = 3 << 4 | i;
            let (packet_type, flag) = parse_packet_type(input).unwrap();
            assert_eq!(packet_type, PacketType::Publish);
            assert_eq!(flag, PacketFlags(i));
        }
    }

    #[test]
    #[ignore]
    fn remaining_length() {
        // Exhaustive round trip over 0..=268_435_455 (the full encodable range).
        (0u32..268_435_456).into_par_iter().for_each(|i| {
            let mut buf = [0u8; 4];
            let expected_offset = encode_remaining_length(i, &mut buf);
            // Format the failure message lazily; building it eagerly would
            // allocate on every one of ~268M iterations.
            let (offset, len) = parse_remaining_length(&buf)
                .unwrap_or_else(|err| panic!("Failed for number: {}: {:?}", i, err))
                .unwrap();
            assert_eq!(i, len);
            assert_eq!(expected_offset, offset);
        });
    }

    #[test]
    fn bad_remaining_length() {
        let buf = [0xFF, 0xFF, 0xFF, 0xFF];
        let result = parse_remaining_length(&buf);
        assert_eq!(result, Err(DecodeError::RemainingLength));
    }

    #[test]
    fn bad_remaining_length2() {
        let buf = [0xFF, 0xFF];
        let result = parse_remaining_length(&buf);
        assert_eq!(result, Ok(Status::Partial(1)));
    }

    #[test]
    fn fixed_header1() {
        let buf = [1 << 4 | 0b0000, 0];
        let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap();
        assert_eq!(offset, 2);
        assert_eq!(header.r#type(), PacketType::Connect);
        assert_eq!(header.flags(), PacketFlags(0));
        assert_eq!(header.len(), 0);
    }

    #[test]
    fn fixed_header2() {
        let buf = [3 << 4 | 0b0000, 0x80, 0x80, 0x80, 0x01];
        let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap();
        assert_eq!(offset, 5);
        assert_eq!(header.r#type(), PacketType::Publish);
        assert_eq!(header.flags(), PacketFlags(0));
        assert_eq!(header.len(), 2097152);
    }

    #[test]
    fn fixed_header_roundtrip() {
        let header = FixedHeader::new(PacketType::Subscribe, PacketFlags(0b0010), 321);
        let mut buf = [0u8; 5];
        let written = match header.encode(&mut buf) {
            Ok(n) => n,
            Err(_) => panic!("encoding a fixed header into a 5-byte buffer failed"),
        };
        assert_eq!(written, header.encoded_len());
        // 321 = 2 * 128 + 65 -> [65 | 0x80, 2]
        assert_eq!(&buf[..written], &[8 << 4 | 0b0010, 0xC1, 0x02]);

        let (offset, decoded) = FixedHeader::decode(&buf[..written]).unwrap().unwrap();
        assert_eq!(offset, written);
        assert_eq!(decoded, header);
    }

    #[test]
    fn bad_len() {
        let buf = [3 << 4];
        let result = FixedHeader::decode(&buf);
        assert_eq!(result, Ok(Status::Partial(1)));
    }
}
373}