embedded_mqtt/fixed_header/
mod.rs

1use core::result::Result;
2
3use crate::{
4    codec::{self, Decodable, Encodable},
5    error::{DecodeError, EncodeError},
6    status::Status,
7};
8
9mod packet_type;
10mod packet_flags;
11
12pub use self::{
13    packet_type::PacketType,
14    packet_flags::{
15        PacketFlags,
16        PublishFlags,
17    },
18};
19
20#[derive(Copy, Clone, PartialEq, Eq, Debug)]
21pub struct FixedHeader {
22    r#type: PacketType,
23    flags: PacketFlags,
24    len: u32,
25}
26
27impl FixedHeader {
28    pub fn new(r#type: PacketType, flags: PacketFlags, len: u32) -> Self {
29        FixedHeader {
30            r#type,
31            flags,
32            len
33        }
34    }
35
36    pub fn r#type(&self) -> PacketType {
37        self.r#type
38    }
39
40    pub fn flags(&self) -> PacketFlags {
41        self.flags
42    }
43
44    pub fn len(&self) -> u32 {
45        self.len
46    }
47}
48
49impl<'buf> Decodable<'buf> for FixedHeader {
50    fn decode(bytes: &'buf [u8]) -> Result<Status<(usize, Self)>, DecodeError> {
51        // "bytes" must be at least 2 bytes long to be a valid fixed header
52        if bytes.len() < 2 {
53            return Ok(Status::Partial(2 - bytes.len()));
54        }
55
56        let (r#type, flags) = parse_packet_type(bytes[0])?;
57
58        let offset = 1;
59
60        let (offset, len) = read!(parse_remaining_length, bytes, offset);
61
62        Ok(Status::Complete((offset, Self {
63            r#type,
64            flags,
65            len
66        })))
67    }
68}
69
70impl Encodable for FixedHeader {
71    fn encoded_len(&self) -> usize {
72        let mut buf = [0u8; 4];
73        let u = encode_remaining_length(self.len, &mut buf);
74        1 + u
75    }
76
77    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
78        let offset = 0;
79        let offset = {
80            let o = codec::values::encode_u8(encode_packet_type(self.r#type, self.flags), &mut bytes[offset..])?;
81            offset + o
82        };
83        let offset = {
84            let mut remaining_length = [0u8; 4];
85            let o = encode_remaining_length(self.len, &mut remaining_length);
86            (&mut bytes[offset..offset+o]).copy_from_slice(&remaining_length[..o]);
87            offset + o
88        };
89        Ok(offset)
90    }
91}
92
93fn parse_remaining_length(bytes: &[u8]) -> Result<Status<(usize, u32)>, DecodeError> {
94    let mut multiplier = 1;
95    let mut value = 0u32;
96    let mut index = 0;
97
98    loop {
99        if multiplier > 128 * 128 * 128 {
100            return Err(DecodeError::RemainingLength);
101        }
102
103        if index >= bytes.len() {
104            return Ok(Status::Partial(1));
105        }
106
107        let byte = bytes[index];
108        index += 1;
109
110        value += (byte & 0b01111111) as u32 * multiplier;
111
112        multiplier *= 128;
113
114        if byte & 128 == 0 {
115            return Ok(Status::Complete((index, value)));
116        }
117    }
118}
119
/// Encodes `len` as a variable-length "remaining length" field into `buf`.
///
/// Writes 7 value bits per byte, least-significant group first, and sets the
/// high bit (0x80) on every byte except the last. Returns the number of bytes
/// written (1–4). Bytes of `buf` past that count are left untouched.
///
/// # Panics
/// If `len` exceeds 268_435_455 (2^28 - 1), the largest value that fits in
/// 4 bytes of 7 bits each.
fn encode_remaining_length(mut len: u32, buf: &mut [u8; 4]) -> usize {
    // 4 bytes x 7 value bits = 28 bits.
    const MAX_REMAINING_LENGTH: u32 = (1 << 28) - 1;
    // Without this check a larger value would need a 5th byte and fail with
    // an opaque index-out-of-bounds panic on `buf[4]`.
    assert!(
        len <= MAX_REMAINING_LENGTH,
        "remaining length {} exceeds the maximum of {}",
        len,
        MAX_REMAINING_LENGTH
    );

    let mut written = 0;
    loop {
        let mut byte = (len % 128) as u8;
        len /= 128;
        if len > 0 {
            // More groups follow: set the continuation bit.
            byte |= 0x80;
        }
        buf[written] = byte;
        written += 1;

        if len == 0 {
            return written;
        }
    }
}
136
137fn parse_packet_type(inp: u8) -> Result<(PacketType, PacketFlags), DecodeError> {
138    // high 4 bits are the packet type
139    let packet_type = match (inp & 0xF0) >> 4 {
140        1 => PacketType::Connect,
141        2 => PacketType::Connack,
142        3 => PacketType::Publish,
143        4 => PacketType::Puback,
144        5 => PacketType::Pubrec,
145        6 => PacketType::Pubrel,
146        7 => PacketType::Pubcomp,
147        8 => PacketType::Subscribe,
148        9 => PacketType::Suback,
149        10 => PacketType::Unsubscribe,
150        11 => PacketType::Unsuback,
151        12 => PacketType::Pingreq,
152        13 => PacketType::Pingresp,
153        14 => PacketType::Disconnect,
154        _ => return Err(DecodeError::PacketType),
155    };
156
157    // low 4 bits represent control flags
158    let flags = PacketFlags(inp & 0xF);
159
160    validate_flag(packet_type, flags)
161}
162
163fn encode_packet_type(r#type: PacketType, flags: PacketFlags) -> u8 {
164    let packet_type: u8 = match r#type {
165        PacketType::Connect => 1,
166        PacketType::Connack => 2,
167        PacketType::Publish => 3,
168        PacketType::Puback => 4,
169        PacketType::Pubrec => 5,
170        PacketType::Pubrel => 6,
171        PacketType::Pubcomp => 7,
172        PacketType::Subscribe => 8,
173        PacketType::Suback => 9,
174        PacketType::Unsubscribe => 10,
175        PacketType::Unsuback => 11,
176        PacketType::Pingreq => 12,
177        PacketType::Pingresp => 13,
178        PacketType::Disconnect => 14,
179    };
180
181    (packet_type << 4) | flags.0
182}
183
184fn validate_flag(packet_type: PacketType, flags: PacketFlags) -> Result<(PacketType, PacketFlags), DecodeError> {
185    // for the following packet types, the control flag MUST be zero
186    const ZERO_TYPES: &[PacketType] = &[
187        PacketType::Connect,
188        PacketType::Connack,
189        PacketType::Puback,
190        PacketType::Pubrec,
191        PacketType::Pubcomp,
192        PacketType::Suback,
193        PacketType::Unsuback,
194        PacketType::Pingreq,
195        PacketType::Pingresp,
196        PacketType::Disconnect,
197    ];
198    // for the following packet types, the control flag MUST be 0b0010
199    const ONE_TYPES: &[PacketType] = &[
200        PacketType::Pubrel,
201        PacketType::Subscribe,
202        PacketType::Unsubscribe,
203    ];
204
205    validate_flag_val(packet_type, flags, ZERO_TYPES, PacketFlags(0b0000))
206        .and_then(|_| validate_flag_val(packet_type, flags, ONE_TYPES, PacketFlags(0b0010)))
207}
208
209fn validate_flag_val(
210    packet_type: PacketType,
211    flags: PacketFlags,
212    types: &[PacketType],
213    expected_flags: PacketFlags,
214) -> Result<(PacketType, PacketFlags), DecodeError> {
215    if let Some(_) = types.iter().find(|&&v| v == packet_type) {
216        if flags != expected_flags {
217            return Err(DecodeError::PacketFlag);
218        }
219    }
220
221    Ok((packet_type, flags))
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;

    #[test]
    fn packet_type() {
        // (type code, valid flag nibble, expected type)
        let inputs: [(u8, u8, PacketType); 14] = [
            (1, 0b0000, PacketType::Connect),
            (2, 0b0000, PacketType::Connack),
            (3, 0b0000, PacketType::Publish),
            (4, 0b0000, PacketType::Puback),
            (5, 0b0000, PacketType::Pubrec),
            (6, 0b0010, PacketType::Pubrel),
            (7, 0b0000, PacketType::Pubcomp),
            (8, 0b0010, PacketType::Subscribe),
            (9, 0b0000, PacketType::Suback),
            (10, 0b0010, PacketType::Unsubscribe),
            (11, 0b0000, PacketType::Unsuback),
            (12, 0b0000, PacketType::Pingreq),
            (13, 0b0000, PacketType::Pingresp),
            (14, 0b0000, PacketType::Disconnect),
        ];

        for &(code, flags, expected_type) in inputs.iter() {
            let (packet_type, flag) = parse_packet_type((code << 4) | flags).unwrap();
            assert_eq!(packet_type, expected_type);
            assert_eq!(flag, PacketFlags(flags));
        }
    }

    #[test]
    fn bad_packet_type() {
        assert_eq!(parse_packet_type(0x00), Err(DecodeError::PacketType));
        assert_eq!(parse_packet_type(0xF0), Err(DecodeError::PacketType));
    }

    #[test]
    fn bad_zero_flags() {
        let inputs: [(u8, PacketType); 10] = [
            (1, PacketType::Connect),
            (2, PacketType::Connack),
            (4, PacketType::Puback),
            (5, PacketType::Pubrec),
            (7, PacketType::Pubcomp),
            (9, PacketType::Suback),
            (11, PacketType::Unsuback),
            (12, PacketType::Pingreq),
            (13, PacketType::Pingresp),
            (14, PacketType::Disconnect),
        ];
        for &(code, packet_type) in inputs.iter() {
            let result = parse_packet_type((code << 4) | 0b0001);
            assert_eq!(result, Err(DecodeError::PacketFlag), "{:?}", packet_type);
        }
    }

    #[test]
    fn bad_one_flags() {
        let inputs: [(u8, PacketType); 3] = [
            (6, PacketType::Pubrel),
            (8, PacketType::Subscribe),
            (10, PacketType::Unsubscribe),
        ];
        for &(code, packet_type) in inputs.iter() {
            let result = parse_packet_type(code << 4);
            assert_eq!(result, Err(DecodeError::PacketFlag), "{:?}", packet_type);
        }
    }

    #[test]
    fn publish_flags() {
        // Every flag nibble, including 0b1111, is accepted for Publish here.
        for i in 0..=15 {
            let input = (3 << 4) | i;
            let (packet_type, flag) = parse_packet_type(input).unwrap();
            assert_eq!(packet_type, PacketType::Publish);
            assert_eq!(flag, PacketFlags(i));
        }
    }

    #[test]
    #[ignore]
    fn remaining_length() {
        // NOTE: exhaustive over every encodable value; this can take a while.
        (0u32..268_435_456).into_par_iter().for_each(|i| {
            let mut buf = [0u8; 4];
            let expected_offset = encode_remaining_length(i, &mut buf);
            let (offset, len) = parse_remaining_length(&buf)
                .unwrap_or_else(|e| panic!("Failed for number {}: {:?}", i, e))
                .unwrap();
            assert_eq!(i, len);
            assert_eq!(expected_offset, offset);
        });
    }

    #[test]
    fn remaining_length_boundaries() {
        let values = [0, 127, 128, 16_383, 16_384, 2_097_151, 2_097_152, 268_435_455];
        for &i in values.iter() {
            let mut buf = [0u8; 4];
            let expected_offset = encode_remaining_length(i, &mut buf);
            let (offset, len) = parse_remaining_length(&buf).unwrap().unwrap();
            assert_eq!((offset, len), (expected_offset, i));
        }
    }

    #[test]
    fn bad_remaining_length() {
        let buf = [0xFF, 0xFF, 0xFF, 0xFF];
        let result = parse_remaining_length(&buf);
        assert_eq!(result, Err(DecodeError::RemainingLength));
    }

    #[test]
    fn bad_remaining_length2() {
        let buf = [0xFF, 0xFF];
        let result = parse_remaining_length(&buf);
        assert_eq!(result, Ok(Status::Partial(1)));
    }

    #[test]
    fn fixed_header1() {
        let buf = [
            0x10, // PacketType::Connect, flags 0b0000
            0x00, // remaining length
        ];
        let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap();
        assert_eq!(offset, 2);
        assert_eq!(header.r#type(), PacketType::Connect);
        assert_eq!(header.flags(), PacketFlags(0));
        assert_eq!(header.len(), 0);
    }

    #[test]
    fn fixed_header2() {
        let buf = [
            0x30, // PacketType::Publish, flags 0b0000
            0x80, // remaining length
            0x80,
            0x80,
            0x01,
        ];
        let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap();
        assert_eq!(offset, 5);
        assert_eq!(header.r#type(), PacketType::Publish);
        assert_eq!(header.flags(), PacketFlags(0));
        assert_eq!(header.len(), 2097152);
    }

    #[test]
    fn fixed_header_round_trip() {
        let header = FixedHeader::new(PacketType::Subscribe, PacketFlags(0b0010), 321);
        let mut buf = [0u8; 8];
        let written = header.encode(&mut buf).ok().expect("buffer is large enough");
        assert_eq!(written, header.encoded_len());
        assert_eq!(&buf[..written], &[0x82u8, 0xC1, 0x02][..]);

        let (offset, decoded) = FixedHeader::decode(&buf[..written]).unwrap().unwrap();
        assert_eq!(offset, written);
        assert_eq!(decoded, header);
    }

    #[test]
    fn bad_len() {
        let buf = [0x30];
        let result = FixedHeader::decode(&buf);
        assert_eq!(result, Ok(Status::Partial(1)));
    }
}