1use bytes::{Buf, BufMut, BytesMut};
2use std::io::Cursor;
3use tokio_util::codec::{Decoder, Encoder};
4
5use crate::error::MqttError;
6use crate::packet::{
7 Connect, MqttPacket, Property, ProtocolLevel, PubAck, PubComp, PubRec, PubRel, Publish, SubAck,
8 Subscribe, UnsubAck, Unsubscribe,
9};
10use crate::utils::read_var_int;
11
/// Codec that frames MQTT control packets for use with `tokio_util::codec`.
pub struct MqttCodec {
    /// Protocol level applied when decoding and encoding version-dependent
    /// packets. Starts as `ProtocolLevel::V311` and is overwritten with the
    /// level announced by each decoded CONNECT packet.
    pub protocol_level: ProtocolLevel,
    /// Optional upper bound on a packet's remaining length. When set, decoding
    /// a packet that announces a larger remaining length fails with
    /// `MqttError::PayloadTooLarge`; `None` disables the check.
    pub max_packet_size: Option<usize>,
}
22
23impl Default for MqttCodec {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl MqttCodec {
30 pub fn new() -> Self {
31 Self {
32 protocol_level: ProtocolLevel::V311,
33 max_packet_size: None,
34 }
35 }
36
37 pub fn with_max_packet_size(max_size: usize) -> Self {
42 Self {
43 protocol_level: ProtocolLevel::V311,
44 max_packet_size: Some(max_size),
45 }
46 }
47}
48
49impl Decoder for MqttCodec {
50 type Item = MqttPacket;
51 type Error = MqttError;
52
53 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54 if src.is_empty() {
55 return Ok(None);
56 }
57
58 let mut cursor = Cursor::new(&src[..]);
59 let fixed_header = cursor.get_u8();
60 let packet_type = fixed_header >> 4;
61 let flags = fixed_header & 0x0F;
62
63 let var_int_result = read_var_int(&mut cursor)?;
64 let remaining_length = match var_int_result {
65 Some((len, _)) => len as usize,
66 None => return Ok(None), };
68
69 let header_len = cursor.position() as usize;
70 let total_len = header_len + remaining_length;
71
72 if let Some(max) = self.max_packet_size {
75 if remaining_length > max {
76 if src.len() >= total_len {
79 src.advance(total_len);
80 }
81 return Err(MqttError::PayloadTooLarge {
82 size: remaining_length,
83 limit: max,
84 });
85 }
86 }
87
88 if src.len() < total_len {
89 src.reserve(total_len - src.len());
90 return Ok(None); }
92
93 let packet_bytes = src.split_to(total_len).freeze();
95
96 let mut payload_cursor = Cursor::new(&packet_bytes[header_len..]);
97
98 let packet = match packet_type {
99 1 => {
100 let protocol_name_len = payload_cursor.get_u16() as usize;
102 let mut protocol_name = vec![0; protocol_name_len];
103 payload_cursor.copy_to_slice(&mut protocol_name);
104
105 let protocol_level_byte = payload_cursor.get_u8();
106 let protocol_level = match protocol_level_byte {
107 4 => ProtocolLevel::V311,
108 5 => ProtocolLevel::V5,
109 _ => return Err(MqttError::UnsupportedVersion),
110 };
111
112 let connect_flags = payload_cursor.get_u8();
113 let clean_session = (connect_flags & 0x02) != 0;
114 let keep_alive = payload_cursor.get_u16();
115
116 if protocol_level == ProtocolLevel::V5 {
118 if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
119 payload_cursor.advance(props_len as usize); } else {
121 return Err(MqttError::MalformedPacket("Incomplete v5 properties"));
122 }
123 }
124
125 let client_id_len = payload_cursor.get_u16() as usize;
127 let mut client_id_bytes = vec![0; client_id_len];
128 payload_cursor.copy_to_slice(&mut client_id_bytes);
129 let client_id = String::from_utf8_lossy(&client_id_bytes).to_string();
130
131 self.protocol_level = protocol_level;
133
134 MqttPacket::Connect(Connect {
135 protocol_level,
136 client_id,
137 clean_session,
138 keep_alive,
139 })
140 }
141 3 => {
142 let dup = (flags & 0x08) != 0;
144 let qos = (flags & 0x06) >> 1;
145 let retain = (flags & 0x01) != 0;
146
147 let topic_len = payload_cursor.get_u16() as usize;
148 let mut topic_bytes = vec![0; topic_len];
149 payload_cursor.copy_to_slice(&mut topic_bytes);
150 let topic = String::from_utf8_lossy(&topic_bytes).to_string();
151
152 let packet_id = if qos > 0 {
153 Some(payload_cursor.get_u16())
154 } else {
155 None
156 };
157
158 let mut properties = Vec::new();
159
160 if self.protocol_level == ProtocolLevel::V5 {
162 if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
163 let props_end = payload_cursor.position() as usize + props_len as usize;
164 if total_len < header_len + props_end {
165 return Err(MqttError::MalformedPacket(
166 "Properties length exceeds packet",
167 ));
168 }
169 properties = parse_properties(&mut payload_cursor, props_len as usize)?;
170 } else {
171 return Err(MqttError::MalformedPacket(
172 "Incomplete v5 properties in PUBLISH",
173 ));
174 }
175 }
176
177 let payload_start = header_len + payload_cursor.position() as usize;
179 let payload = packet_bytes.slice(payload_start..total_len);
180
181 MqttPacket::Publish(Publish {
182 dup,
183 qos,
184 retain,
185 topic,
186 packet_id,
187 properties,
188 payload,
189 })
190 }
191 4 => {
192 let packet_id = payload_cursor.get_u16();
193 let reason_code =
195 if self.protocol_level == ProtocolLevel::V5 && remaining_length > 2 {
196 Some(payload_cursor.get_u8())
197 } else {
198 None
199 };
200 MqttPacket::PubAck(PubAck {
201 packet_id,
202 reason_code,
203 })
204 }
205 5 => MqttPacket::PubRec(PubRec {
206 packet_id: payload_cursor.get_u16(),
207 }),
208 6 => MqttPacket::PubRel(PubRel {
209 packet_id: payload_cursor.get_u16(),
210 }),
211 7 => MqttPacket::PubComp(PubComp {
212 packet_id: payload_cursor.get_u16(),
213 }),
214 8 => {
215 let packet_id = payload_cursor.get_u16();
217 let mut filters = Vec::new();
218 while payload_cursor.has_remaining() {
219 let topic_len = payload_cursor.get_u16() as usize;
220 let mut topic_bytes = vec![0; topic_len];
221 payload_cursor.copy_to_slice(&mut topic_bytes);
222 let topic = String::from_utf8_lossy(&topic_bytes).to_string();
223 let qos = payload_cursor.get_u8();
224 filters.push((topic, qos));
225 }
226 MqttPacket::Subscribe(Subscribe { packet_id, filters })
227 }
228 9 => {
229 let packet_id = payload_cursor.get_u16();
231 let mut return_codes = Vec::new();
232 while payload_cursor.has_remaining() {
233 return_codes.push(payload_cursor.get_u8());
234 }
235 MqttPacket::SubAck(SubAck {
236 packet_id,
237 return_codes,
238 })
239 }
240 10 => {
241 let packet_id = payload_cursor.get_u16();
243 let mut filters = Vec::new();
244 while payload_cursor.has_remaining() {
245 let topic_len = payload_cursor.get_u16() as usize;
246 let mut topic_bytes = vec![0; topic_len];
247 payload_cursor.copy_to_slice(&mut topic_bytes);
248 filters.push(String::from_utf8_lossy(&topic_bytes).to_string());
249 }
250 MqttPacket::Unsubscribe(Unsubscribe { packet_id, filters })
251 }
252 11 => MqttPacket::UnsubAck(UnsubAck {
253 packet_id: payload_cursor.get_u16(),
254 }),
255 12 => MqttPacket::PingReq,
256 13 => MqttPacket::PingResp,
257 14 => MqttPacket::Disconnect,
258 _ => {
259 return Err(MqttError::ProtocolError(format!(
260 "Unsupported packet type: {}",
261 packet_type
262 )))
263 }
264 };
265
266 Ok(Some(packet))
267 }
268}
269
270impl Encoder<MqttPacket> for MqttCodec {
271 type Error = MqttError;
272
273 fn encode(&mut self, item: MqttPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
274 match item {
275 MqttPacket::ConnAck(connack) => {
276 dst.put_u8(0x20); dst.put_u8(2); dst.put_u8(if connack.session_present { 1 } else { 0 });
279 dst.put_u8(connack.return_code);
280 }
281 MqttPacket::PingResp => {
282 dst.put_u8(0xD0); dst.put_u8(0); }
285 MqttPacket::PubAck(puback) => {
286 dst.put_u8(0x40);
287 if self.protocol_level == ProtocolLevel::V5 {
288 let reason = puback.reason_code.unwrap_or(0x00);
289 if reason == 0x00 {
290 dst.put_u8(2); dst.put_u16(puback.packet_id);
295 } else {
296 dst.put_u8(4); dst.put_u16(puback.packet_id);
299 dst.put_u8(reason);
300 dst.put_u8(0); }
302 } else {
303 dst.put_u8(2); dst.put_u16(puback.packet_id);
306 }
307 }
308 MqttPacket::PubRec(pubrec) => {
309 dst.put_u8(0x50);
310 dst.put_u8(2);
311 dst.put_u16(pubrec.packet_id);
312 }
313 MqttPacket::PubRel(pubrel) => {
314 dst.put_u8(0x62);
315 dst.put_u8(2);
316 dst.put_u16(pubrel.packet_id);
317 }
318 MqttPacket::PubComp(pubcomp) => {
319 dst.put_u8(0x70);
320 dst.put_u8(2);
321 dst.put_u16(pubcomp.packet_id);
322 }
323 MqttPacket::SubAck(suback) => {
324 dst.put_u8(0x90);
325 let props_len = if self.protocol_level == ProtocolLevel::V5 {
327 1
328 } else {
329 0
330 };
331 let remaining_len = 2 + suback.return_codes.len() as u32 + props_len;
332 crate::utils::write_var_int(remaining_len, dst)?;
333 dst.put_u16(suback.packet_id);
334
335 if self.protocol_level == ProtocolLevel::V5 {
336 dst.put_u8(0); }
338
339 for rc in suback.return_codes {
340 dst.put_u8(rc);
341 }
342 }
343 MqttPacket::UnsubAck(unsuback) => {
344 dst.put_u8(0xB0);
345 dst.put_u8(2);
346 dst.put_u16(unsuback.packet_id);
347 }
348 MqttPacket::PingReq => {
349 dst.put_u8(0xC0);
350 dst.put_u8(0);
351 }
352
353 MqttPacket::Disconnect => {
354 dst.put_u8(0xE0);
355 dst.put_u8(0);
356 }
357 _ => {
358 return Err(MqttError::ProtocolError(
359 "Packet encoding not implemented for this type".into(),
360 ))
361 }
362 }
363 Ok(())
364 }
365}
366
367pub fn parse_properties(
368 cursor: &mut Cursor<&[u8]>,
369 length: usize,
370) -> Result<Vec<Property>, MqttError> {
371 let mut properties = Vec::new();
372 let start_pos = cursor.position() as usize;
373
374 while (cursor.position() as usize - start_pos) < length {
375 if let Some((identifier, _)) = read_var_int(cursor)? {
376 match identifier {
377 0x01 => properties.push(Property::PayloadFormatIndicator(cursor.get_u8())),
378 0x02 => properties.push(Property::MessageExpiryInterval(cursor.get_u32())),
379 0x03 => {
380 let str_len = cursor.get_u16() as usize;
381 let mut str_bytes = vec![0; str_len];
382 cursor.copy_to_slice(&mut str_bytes);
383 properties.push(Property::ContentType(
384 String::from_utf8_lossy(&str_bytes).to_string(),
385 ));
386 }
387 0x08 => {
388 let str_len = cursor.get_u16() as usize;
389 let mut str_bytes = vec![0; str_len];
390 cursor.copy_to_slice(&mut str_bytes);
391 properties.push(Property::ResponseTopic(
392 String::from_utf8_lossy(&str_bytes).to_string(),
393 ));
394 }
395 0x09 => {
396 let bin_len = cursor.get_u16() as usize;
397 let mut bin_bytes = vec![0; bin_len];
398 cursor.copy_to_slice(&mut bin_bytes);
399 properties.push(Property::CorrelationData(bin_bytes));
400 }
401 0x0B => {
402 if let Some((sub_id, _)) = read_var_int(cursor)? {
403 properties.push(Property::SubscriptionIdentifier(sub_id));
404 }
405 }
406 0x23 => properties.push(Property::TopicAlias(cursor.get_u16())),
407 0x26 => {
408 let k_len = cursor.get_u16() as usize;
409 let mut k_bytes = vec![0; k_len];
410 cursor.copy_to_slice(&mut k_bytes);
411 let v_len = cursor.get_u16() as usize;
412 let mut v_bytes = vec![0; v_len];
413 cursor.copy_to_slice(&mut v_bytes);
414 properties.push(Property::UserProperty(
415 String::from_utf8_lossy(&k_bytes).to_string(),
416 String::from_utf8_lossy(&v_bytes).to_string(),
417 ));
418 }
419 _ => return Err(MqttError::MalformedPacket("Unknown property identifier")),
420 }
421 } else {
422 break;
423 }
424 }
425
426 Ok(properties)
427}