mqtt/control/
fixed_header.rs

1//! Fixed header in MQTT
2
3use std::io::{self, Read, Write};
4
5use byteorder::{ReadBytesExt, WriteBytesExt};
6
7#[cfg(feature = "tokio")]
8use tokio::io::{AsyncRead, AsyncReadExt};
9
10use crate::control::packet_type::{PacketType, PacketTypeError};
11use crate::{Decodable, Encodable};
12
13/// Fixed header for each MQTT control packet
14///
15/// Format:
16///
17/// ```plain
18/// 7                          3                          0
19/// +--------------------------+--------------------------+
20/// | MQTT Control Packet Type | Flags for each type      |
21/// +--------------------------+--------------------------+
22/// | Remaining Length ...                                |
23/// +-----------------------------------------------------+
24/// ```
25#[derive(Debug, Clone, Copy, Eq, PartialEq)]
26pub struct FixedHeader {
27    /// Packet Type
28    pub packet_type: PacketType,
29
30    /// The Remaining Length is the number of bytes remaining within the current packet,
31    /// including data in the variable header and the payload. The Remaining Length does
32    /// not include the bytes used to encode the Remaining Length.
33    pub remaining_length: u32,
34}
35
36impl FixedHeader {
37    pub fn new(packet_type: PacketType, remaining_length: u32) -> FixedHeader {
38        debug_assert!(remaining_length <= 0x0FFF_FFFF);
39        FixedHeader {
40            packet_type,
41            remaining_length,
42        }
43    }
44
45    #[cfg(feature = "tokio")]
46    /// Asynchronously parse a single fixed header from an AsyncRead type, such as a network
47    /// socket.
48    ///
49    /// This requires mqtt-rs to be built with `feature = "tokio"`
50    pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, FixedHeaderError> {
51        let type_val = rdr.read_u8().await?;
52
53        let mut remaining_len = 0;
54        let mut i = 0;
55
56        loop {
57            let byte = rdr.read_u8().await?;
58
59            remaining_len |= (u32::from(byte) & 0x7F) << (7 * i);
60
61            if i >= 4 {
62                return Err(FixedHeaderError::MalformedRemainingLength);
63            }
64
65            if byte & 0x80 == 0 {
66                break;
67            } else {
68                i += 1;
69            }
70        }
71
72        match PacketType::from_u8(type_val) {
73            Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
74            Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)),
75            Err(err) => Err(From::from(err)),
76        }
77    }
78}
79
80impl Encodable for FixedHeader {
81    fn encode<W: Write>(&self, wr: &mut W) -> Result<(), io::Error> {
82        wr.write_u8(self.packet_type.to_u8())?;
83
84        let mut cur_len = self.remaining_length;
85        loop {
86            let mut byte = (cur_len & 0x7F) as u8;
87            cur_len >>= 7;
88
89            if cur_len > 0 {
90                byte |= 0x80;
91            }
92
93            wr.write_u8(byte)?;
94
95            if cur_len == 0 {
96                break;
97            }
98        }
99
100        Ok(())
101    }
102
103    fn encoded_length(&self) -> u32 {
104        let rem_size = if self.remaining_length >= 2_097_152 {
105            4
106        } else if self.remaining_length >= 16_384 {
107            3
108        } else if self.remaining_length >= 128 {
109            2
110        } else {
111            1
112        };
113        1 + rem_size
114    }
115}
116
117impl Decodable for FixedHeader {
118    type Error = FixedHeaderError;
119    type Cond = ();
120
121    fn decode_with<R: Read>(rdr: &mut R, _rest: ()) -> Result<FixedHeader, FixedHeaderError> {
122        let type_val = rdr.read_u8()?;
123        let remaining_len = {
124            let mut cur = 0u32;
125            for i in 0.. {
126                let byte = rdr.read_u8()?;
127                cur |= ((byte as u32) & 0x7F) << (7 * i);
128
129                if i >= 4 {
130                    return Err(FixedHeaderError::MalformedRemainingLength);
131                }
132
133                if byte & 0x80 == 0 {
134                    break;
135                }
136            }
137
138            cur
139        };
140
141        match PacketType::from_u8(type_val) {
142            Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
143            Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)),
144            Err(err) => Err(From::from(err)),
145        }
146    }
147}
148
149#[derive(Debug, thiserror::Error)]
150pub enum FixedHeaderError {
151    #[error("malformed remaining length")]
152    MalformedRemainingLength,
153    #[error("reserved header ({0}, {1})")]
154    ReservedType(u8, u32),
155    #[error(transparent)]
156    PacketTypeError(#[from] PacketTypeError),
157    #[error(transparent)]
158    IoError(#[from] io::Error),
159}
160
#[cfg(test)]
mod test {
    use super::*;

    use crate::control::packet_type::{ControlType, PacketType};
    use crate::{Decodable, Encodable};
    use std::io::Cursor;

    /// Remaining Length values at every 1/2/3/4-byte boundary, paired with the
    /// length bytes they must encode to.
    const LENGTH_VECTORS: &[(u32, &[u8])] = &[
        (0, &[0x00]),
        (127, &[0x7F]),
        (128, &[0x80, 0x01]),
        (16_383, &[0xFF, 0x7F]),
        (16_384, &[0x80, 0x80, 0x01]),
        (2_097_151, &[0xFF, 0xFF, 0x7F]),
        (2_097_152, &[0x80, 0x80, 0x80, 0x01]),
        (268_435_455, &[0xFF, 0xFF, 0xFF, 0x7F]),
    ];

    #[test]
    fn test_encode_fixed_header() {
        let header = FixedHeader::new(PacketType::with_default(ControlType::Connect), 321);
        let mut buf = Vec::new();
        header.encode(&mut buf).unwrap();

        let expected = b"\x10\xc1\x02";
        assert_eq!(&expected[..], &buf[..]);
    }

    #[test]
    fn test_decode_fixed_header() {
        let stream = b"\x10\xc1\x02";
        let mut cursor = Cursor::new(&stream[..]);
        let header = FixedHeader::decode(&mut cursor).unwrap();
        assert_eq!(header.packet_type, PacketType::with_default(ControlType::Connect));
        assert_eq!(header.remaining_length, 321);
    }

    /// Encoding, `encoded_length` and decoding must agree at every length boundary.
    #[test]
    fn test_remaining_length_boundaries() {
        for &(len, len_bytes) in LENGTH_VECTORS {
            let header = FixedHeader::new(PacketType::with_default(ControlType::Connect), len);

            let mut buf = Vec::new();
            header.encode(&mut buf).unwrap();
            assert_eq!(buf[0], 0x10);
            assert_eq!(&buf[1..], len_bytes);
            assert_eq!(header.encoded_length() as usize, buf.len());

            let decoded = FixedHeader::decode(&mut Cursor::new(&buf[..])).unwrap();
            assert_eq!(decoded, header);
        }
    }

    /// A length whose fourth byte still has the continuation bit set must be
    /// rejected with the dedicated error, not merely fail in some way.
    #[test]
    fn test_decode_too_long_fixed_header() {
        let stream = b"\x10\x80\x80\x80\x80\x02";
        let mut cursor = Cursor::new(&stream[..]);
        match FixedHeader::decode(&mut cursor) {
            Err(FixedHeaderError::MalformedRemainingLength) => {}
            other => panic!("expected MalformedRemainingLength, got {:?}", other),
        }
    }
}