mqtt/control/
fixed_header.rs1use std::io::{self, Read, Write};
4
5use byteorder::{ReadBytesExt, WriteBytesExt};
6
7#[cfg(feature = "tokio")]
8use tokio::io::{AsyncRead, AsyncReadExt};
9
10use crate::control::packet_type::{PacketType, PacketTypeError};
11use crate::{Decodable, Encodable};
12
13#[derive(Debug, Clone, Copy, Eq, PartialEq)]
26pub struct FixedHeader {
27 pub packet_type: PacketType,
29
30 pub remaining_length: u32,
34}
35
36impl FixedHeader {
37 pub fn new(packet_type: PacketType, remaining_length: u32) -> FixedHeader {
38 debug_assert!(remaining_length <= 0x0FFF_FFFF);
39 FixedHeader {
40 packet_type,
41 remaining_length,
42 }
43 }
44
45 #[cfg(feature = "tokio")]
46 pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, FixedHeaderError> {
51 let type_val = rdr.read_u8().await?;
52
53 let mut remaining_len = 0;
54 let mut i = 0;
55
56 loop {
57 let byte = rdr.read_u8().await?;
58
59 remaining_len |= (u32::from(byte) & 0x7F) << (7 * i);
60
61 if i >= 4 {
62 return Err(FixedHeaderError::MalformedRemainingLength);
63 }
64
65 if byte & 0x80 == 0 {
66 break;
67 } else {
68 i += 1;
69 }
70 }
71
72 match PacketType::from_u8(type_val) {
73 Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
74 Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)),
75 Err(err) => Err(From::from(err)),
76 }
77 }
78}
79
80impl Encodable for FixedHeader {
81 fn encode<W: Write>(&self, wr: &mut W) -> Result<(), io::Error> {
82 wr.write_u8(self.packet_type.to_u8())?;
83
84 let mut cur_len = self.remaining_length;
85 loop {
86 let mut byte = (cur_len & 0x7F) as u8;
87 cur_len >>= 7;
88
89 if cur_len > 0 {
90 byte |= 0x80;
91 }
92
93 wr.write_u8(byte)?;
94
95 if cur_len == 0 {
96 break;
97 }
98 }
99
100 Ok(())
101 }
102
103 fn encoded_length(&self) -> u32 {
104 let rem_size = if self.remaining_length >= 2_097_152 {
105 4
106 } else if self.remaining_length >= 16_384 {
107 3
108 } else if self.remaining_length >= 128 {
109 2
110 } else {
111 1
112 };
113 1 + rem_size
114 }
115}
116
117impl Decodable for FixedHeader {
118 type Error = FixedHeaderError;
119 type Cond = ();
120
121 fn decode_with<R: Read>(rdr: &mut R, _rest: ()) -> Result<FixedHeader, FixedHeaderError> {
122 let type_val = rdr.read_u8()?;
123 let remaining_len = {
124 let mut cur = 0u32;
125 for i in 0.. {
126 let byte = rdr.read_u8()?;
127 cur |= ((byte as u32) & 0x7F) << (7 * i);
128
129 if i >= 4 {
130 return Err(FixedHeaderError::MalformedRemainingLength);
131 }
132
133 if byte & 0x80 == 0 {
134 break;
135 }
136 }
137
138 cur
139 };
140
141 match PacketType::from_u8(type_val) {
142 Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
143 Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)),
144 Err(err) => Err(From::from(err)),
145 }
146 }
147}
148
/// Errors returned while reading a [`FixedHeader`].
#[derive(Debug, thiserror::Error)]
pub enum FixedHeaderError {
    /// The remaining-length field is not terminated within the four bytes the
    /// decoder accepts.
    #[error("malformed remaining length")]
    MalformedRemainingLength,
    /// `PacketType::from_u8` reported a reserved type. Holds the first field of
    /// that `PacketTypeError::ReservedType` and the decoded remaining length.
    #[error("reserved header ({0}, {1})")]
    ReservedType(u8, u32),
    /// Any other error from `PacketType::from_u8`.
    #[error(transparent)]
    PacketTypeError(#[from] PacketTypeError),
    /// Reading a header byte from the underlying reader failed.
    #[error(transparent)]
    IoError(#[from] io::Error),
}
160
#[cfg(test)]
mod test {
    use super::*;

    use crate::control::packet_type::{ControlType, PacketType};
    use crate::{Decodable, Encodable};
    use std::io::Cursor;

    #[test]
    fn test_encode_fixed_header() {
        let header = FixedHeader::new(PacketType::with_default(ControlType::Connect), 321);
        let mut buf = Vec::new();
        header.encode(&mut buf).unwrap();

        let expected = b"\x10\xc1\x02";
        assert_eq!(&expected[..], &buf[..]);
    }

    #[test]
    fn test_decode_fixed_header() {
        let stream = b"\x10\xc1\x02";
        let mut cursor = Cursor::new(&stream[..]);
        let header = FixedHeader::decode(&mut cursor).unwrap();
        assert_eq!(header.packet_type, PacketType::with_default(ControlType::Connect));
        assert_eq!(header.remaining_length, 321);
    }

    #[test]
    fn test_decode_too_long_fixed_header() {
        // Four continuation bytes: the remaining length never terminates within
        // the allowed four bytes. Assert the specific error rather than relying
        // on `#[should_panic]`, which would also accept an unrelated I/O error.
        let stream = b"\x10\x80\x80\x80\x80\x02";
        let mut cursor = Cursor::new(&stream[..]);
        let err = FixedHeader::decode(&mut cursor).unwrap_err();
        assert!(
            matches!(err, FixedHeaderError::MalformedRemainingLength),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    fn test_remaining_length_roundtrip_boundaries() {
        let packet_type = PacketType::with_default(ControlType::Connect);
        // (remaining length, expected number of remaining-length bytes)
        let cases: &[(u32, usize)] = &[
            (0, 1),
            (127, 1),
            (128, 2),
            (16_383, 2),
            (16_384, 3),
            (2_097_151, 3),
            (2_097_152, 4),
            (0x0FFF_FFFF, 4),
        ];

        for &(len, rem_bytes) in cases {
            let header = FixedHeader::new(packet_type, len);
            let mut buf = Vec::new();
            header.encode(&mut buf).unwrap();

            assert_eq!(buf.len(), 1 + rem_bytes, "encoded size for {}", len);
            assert_eq!(header.encoded_length() as usize, buf.len(), "encoded_length for {}", len);

            let decoded = FixedHeader::decode(&mut Cursor::new(&buf[..])).unwrap();
            assert_eq!(decoded, header, "roundtrip for {}", len);
        }
    }
}