Skip to main content

mqtt_frame/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use std::io::Cursor;
3use tokio_util::codec::{Decoder, Encoder};
4
5use crate::error::MqttError;
6use crate::packet::{
7    Connect, MqttPacket, Property, ProtocolLevel, PubAck, PubComp, PubRec, PubRel, Publish, SubAck,
8    Subscribe, UnsubAck, Unsubscribe,
9};
10use crate::utils::read_var_int;
11
12pub struct MqttCodec {
13    /// Tracks the protocol level negotiated during CONNECT.
14    /// Defaults to V311 for the first packet.
15    pub protocol_level: ProtocolLevel,
16    /// Maximum allowed remaining length (bytes) for any MQTT packet.
17    /// `None` = no limit (backward-compatible default).
18    /// When exceeded, `decode()` returns `MqttError::PayloadTooLarge`
19    /// **before** allocating memory for the packet body.
20    pub max_packet_size: Option<usize>,
21}
22
23impl Default for MqttCodec {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl MqttCodec {
30    pub fn new() -> Self {
31        Self {
32            protocol_level: ProtocolLevel::V311,
33            max_packet_size: None,
34        }
35    }
36
37    /// Create a codec with a maximum packet size limit.
38    ///
39    /// Any packet whose `remaining_length` exceeds `max_size` is rejected
40    /// at the decode level before memory is allocated.
41    pub fn with_max_packet_size(max_size: usize) -> Self {
42        Self {
43            protocol_level: ProtocolLevel::V311,
44            max_packet_size: Some(max_size),
45        }
46    }
47}
48
49impl Decoder for MqttCodec {
50    type Item = MqttPacket;
51    type Error = MqttError;
52
53    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54        if src.is_empty() {
55            return Ok(None);
56        }
57
58        let mut cursor = Cursor::new(&src[..]);
59        let fixed_header = cursor.get_u8();
60        let packet_type = fixed_header >> 4;
61        let flags = fixed_header & 0x0F;
62
63        let var_int_result = read_var_int(&mut cursor)?;
64        let remaining_length = match var_int_result {
65            Some((len, _)) => len as usize,
66            None => return Ok(None), // Not enough data for length
67        };
68
69        let header_len = cursor.position() as usize;
70        let total_len = header_len + remaining_length;
71
72        // Enforce max packet size BEFORE allocating memory.
73        // This prevents a single oversized PUBLISH from exhausting server RAM.
74        if let Some(max) = self.max_packet_size {
75            if remaining_length > max {
76                // If the oversized packet is already fully buffered, drain it
77                // to prevent the decoder from re-reading the same bytes forever.
78                if src.len() >= total_len {
79                    src.advance(total_len);
80                }
81                return Err(MqttError::PayloadTooLarge {
82                    size: remaining_length,
83                    limit: max,
84                });
85            }
86        }
87
88        if src.len() < total_len {
89            src.reserve(total_len - src.len());
90            return Ok(None); // Wait for more data
91        }
92
93        // We have the full packet. Let's slice it out using zero-copy.
94        let packet_bytes = src.split_to(total_len).freeze();
95
96        let mut payload_cursor = Cursor::new(&packet_bytes[header_len..]);
97
98        let packet = match packet_type {
99            1 => {
100                // CONNECT
101                let protocol_name_len = payload_cursor.get_u16() as usize;
102                let mut protocol_name = vec![0; protocol_name_len];
103                payload_cursor.copy_to_slice(&mut protocol_name);
104
105                let protocol_level_byte = payload_cursor.get_u8();
106                let protocol_level = match protocol_level_byte {
107                    4 => ProtocolLevel::V311,
108                    5 => ProtocolLevel::V5,
109                    _ => return Err(MqttError::UnsupportedVersion),
110                };
111
112                let connect_flags = payload_cursor.get_u8();
113                let clean_session = (connect_flags & 0x02) != 0;
114                let keep_alive = payload_cursor.get_u16();
115
116                // Properties (if v5)
117                if protocol_level == ProtocolLevel::V5 {
118                    if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
119                        payload_cursor.advance(props_len as usize); // Skip properties for now
120                    } else {
121                        return Err(MqttError::MalformedPacket("Incomplete v5 properties"));
122                    }
123                }
124
125                // Client ID
126                let client_id_len = payload_cursor.get_u16() as usize;
127                let mut client_id_bytes = vec![0; client_id_len];
128                payload_cursor.copy_to_slice(&mut client_id_bytes);
129                let client_id = String::from_utf8_lossy(&client_id_bytes).to_string();
130
131                // Self-update the codec's protocol level for subsequent packets!
132                self.protocol_level = protocol_level;
133
134                MqttPacket::Connect(Connect {
135                    protocol_level,
136                    client_id,
137                    clean_session,
138                    keep_alive,
139                })
140            }
141            3 => {
142                // PUBLISH
143                let dup = (flags & 0x08) != 0;
144                let qos = (flags & 0x06) >> 1;
145                let retain = (flags & 0x01) != 0;
146
147                let topic_len = payload_cursor.get_u16() as usize;
148                let mut topic_bytes = vec![0; topic_len];
149                payload_cursor.copy_to_slice(&mut topic_bytes);
150                let topic = String::from_utf8_lossy(&topic_bytes).to_string();
151
152                let packet_id = if qos > 0 {
153                    Some(payload_cursor.get_u16())
154                } else {
155                    None
156                };
157
158                let mut properties = Vec::new();
159
160                // If V5, parse properties before extracting payload
161                if self.protocol_level == ProtocolLevel::V5 {
162                    if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
163                        let props_end = payload_cursor.position() as usize + props_len as usize;
164                        if total_len < header_len + props_end {
165                            return Err(MqttError::MalformedPacket(
166                                "Properties length exceeds packet",
167                            ));
168                        }
169                        properties = parse_properties(&mut payload_cursor, props_len as usize)?;
170                    } else {
171                        return Err(MqttError::MalformedPacket(
172                            "Incomplete v5 properties in PUBLISH",
173                        ));
174                    }
175                }
176
177                // Payload is the rest of the packet
178                let payload_start = header_len + payload_cursor.position() as usize;
179                let payload = packet_bytes.slice(payload_start..total_len);
180
181                MqttPacket::Publish(Publish {
182                    dup,
183                    qos,
184                    retain,
185                    topic,
186                    packet_id,
187                    properties,
188                    payload,
189                })
190            }
191            4 => {
192                let packet_id = payload_cursor.get_u16();
193                // MQTT v5: reason code follows packet_id if remaining_length > 2
194                let reason_code =
195                    if self.protocol_level == ProtocolLevel::V5 && remaining_length > 2 {
196                        Some(payload_cursor.get_u8())
197                    } else {
198                        None
199                    };
200                MqttPacket::PubAck(PubAck {
201                    packet_id,
202                    reason_code,
203                })
204            }
205            5 => MqttPacket::PubRec(PubRec {
206                packet_id: payload_cursor.get_u16(),
207            }),
208            6 => MqttPacket::PubRel(PubRel {
209                packet_id: payload_cursor.get_u16(),
210            }),
211            7 => MqttPacket::PubComp(PubComp {
212                packet_id: payload_cursor.get_u16(),
213            }),
214            8 => {
215                // SUBSCRIBE
216                let packet_id = payload_cursor.get_u16();
217                let mut filters = Vec::new();
218                while payload_cursor.has_remaining() {
219                    let topic_len = payload_cursor.get_u16() as usize;
220                    let mut topic_bytes = vec![0; topic_len];
221                    payload_cursor.copy_to_slice(&mut topic_bytes);
222                    let topic = String::from_utf8_lossy(&topic_bytes).to_string();
223                    let qos = payload_cursor.get_u8();
224                    filters.push((topic, qos));
225                }
226                MqttPacket::Subscribe(Subscribe { packet_id, filters })
227            }
228            9 => {
229                // SUBACK
230                let packet_id = payload_cursor.get_u16();
231                let mut return_codes = Vec::new();
232                while payload_cursor.has_remaining() {
233                    return_codes.push(payload_cursor.get_u8());
234                }
235                MqttPacket::SubAck(SubAck {
236                    packet_id,
237                    return_codes,
238                })
239            }
240            10 => {
241                // UNSUBSCRIBE
242                let packet_id = payload_cursor.get_u16();
243                let mut filters = Vec::new();
244                while payload_cursor.has_remaining() {
245                    let topic_len = payload_cursor.get_u16() as usize;
246                    let mut topic_bytes = vec![0; topic_len];
247                    payload_cursor.copy_to_slice(&mut topic_bytes);
248                    filters.push(String::from_utf8_lossy(&topic_bytes).to_string());
249                }
250                MqttPacket::Unsubscribe(Unsubscribe { packet_id, filters })
251            }
252            11 => MqttPacket::UnsubAck(UnsubAck {
253                packet_id: payload_cursor.get_u16(),
254            }),
255            12 => MqttPacket::PingReq,
256            13 => MqttPacket::PingResp,
257            14 => MqttPacket::Disconnect,
258            _ => {
259                return Err(MqttError::ProtocolError(format!(
260                    "Unsupported packet type: {}",
261                    packet_type
262                )))
263            }
264        };
265
266        Ok(Some(packet))
267    }
268}
269
270impl Encoder<MqttPacket> for MqttCodec {
271    type Error = MqttError;
272
273    fn encode(&mut self, item: MqttPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
274        match item {
275            MqttPacket::ConnAck(connack) => {
276                dst.put_u8(0x20); // Type 2 (CONNACK)
277                dst.put_u8(2); // Remaining length is always 2 for v3.1.1
278                dst.put_u8(if connack.session_present { 1 } else { 0 });
279                dst.put_u8(connack.return_code);
280            }
281            MqttPacket::PingResp => {
282                dst.put_u8(0xD0); // Type 13 (PINGRESP)
283                dst.put_u8(0); // Remaining length is 0
284            }
285            MqttPacket::PubAck(puback) => {
286                dst.put_u8(0x40);
287                if self.protocol_level == ProtocolLevel::V5 {
288                    let reason = puback.reason_code.unwrap_or(0x00);
289                    if reason == 0x00 {
290                        // MQTT v5 ยง3.4.2.1: if Reason Code is 0x00 and there
291                        // are no Properties, the Reason Code and Property
292                        // Length can be omitted (short form).
293                        dst.put_u8(2); // Remaining length: 2 (packet_id only)
294                        dst.put_u16(puback.packet_id);
295                    } else {
296                        // Non-success: must include reason code + empty properties
297                        dst.put_u8(4); // Remaining length: 2 (ID) + 1 (reason) + 1 (props=0)
298                        dst.put_u16(puback.packet_id);
299                        dst.put_u8(reason);
300                        dst.put_u8(0); // 0 Properties
301                    }
302                } else {
303                    // MQTT v3.1.1: no reason code
304                    dst.put_u8(2); // Remaining length: 2 (ID)
305                    dst.put_u16(puback.packet_id);
306                }
307            }
308            MqttPacket::PubRec(pubrec) => {
309                dst.put_u8(0x50);
310                dst.put_u8(2);
311                dst.put_u16(pubrec.packet_id);
312            }
313            MqttPacket::PubRel(pubrel) => {
314                dst.put_u8(0x62);
315                dst.put_u8(2);
316                dst.put_u16(pubrel.packet_id);
317            }
318            MqttPacket::PubComp(pubcomp) => {
319                dst.put_u8(0x70);
320                dst.put_u8(2);
321                dst.put_u16(pubcomp.packet_id);
322            }
323            MqttPacket::SubAck(suback) => {
324                dst.put_u8(0x90);
325                // Length is 2 (packet id) + number of return codes + property length (if v5)
326                let props_len = if self.protocol_level == ProtocolLevel::V5 {
327                    1
328                } else {
329                    0
330                };
331                let remaining_len = 2 + suback.return_codes.len() as u32 + props_len;
332                crate::utils::write_var_int(remaining_len, dst)?;
333                dst.put_u16(suback.packet_id);
334
335                if self.protocol_level == ProtocolLevel::V5 {
336                    dst.put_u8(0); // 0 Properties
337                }
338
339                for rc in suback.return_codes {
340                    dst.put_u8(rc);
341                }
342            }
343            MqttPacket::UnsubAck(unsuback) => {
344                dst.put_u8(0xB0);
345                dst.put_u8(2);
346                dst.put_u16(unsuback.packet_id);
347            }
348            MqttPacket::PingReq => {
349                dst.put_u8(0xC0);
350                dst.put_u8(0);
351            }
352
353            MqttPacket::Disconnect => {
354                dst.put_u8(0xE0);
355                dst.put_u8(0);
356            }
357            _ => {
358                return Err(MqttError::ProtocolError(
359                    "Packet encoding not implemented for this type".into(),
360                ))
361            }
362        }
363        Ok(())
364    }
365}
366
367pub fn parse_properties(
368    cursor: &mut Cursor<&[u8]>,
369    length: usize,
370) -> Result<Vec<Property>, MqttError> {
371    let mut properties = Vec::new();
372    let start_pos = cursor.position() as usize;
373
374    while (cursor.position() as usize - start_pos) < length {
375        if let Some((identifier, _)) = read_var_int(cursor)? {
376            match identifier {
377                0x01 => properties.push(Property::PayloadFormatIndicator(cursor.get_u8())),
378                0x02 => properties.push(Property::MessageExpiryInterval(cursor.get_u32())),
379                0x03 => {
380                    let str_len = cursor.get_u16() as usize;
381                    let mut str_bytes = vec![0; str_len];
382                    cursor.copy_to_slice(&mut str_bytes);
383                    properties.push(Property::ContentType(
384                        String::from_utf8_lossy(&str_bytes).to_string(),
385                    ));
386                }
387                0x08 => {
388                    let str_len = cursor.get_u16() as usize;
389                    let mut str_bytes = vec![0; str_len];
390                    cursor.copy_to_slice(&mut str_bytes);
391                    properties.push(Property::ResponseTopic(
392                        String::from_utf8_lossy(&str_bytes).to_string(),
393                    ));
394                }
395                0x09 => {
396                    let bin_len = cursor.get_u16() as usize;
397                    let mut bin_bytes = vec![0; bin_len];
398                    cursor.copy_to_slice(&mut bin_bytes);
399                    properties.push(Property::CorrelationData(bin_bytes));
400                }
401                0x0B => {
402                    if let Some((sub_id, _)) = read_var_int(cursor)? {
403                        properties.push(Property::SubscriptionIdentifier(sub_id));
404                    }
405                }
406                0x23 => properties.push(Property::TopicAlias(cursor.get_u16())),
407                0x26 => {
408                    let k_len = cursor.get_u16() as usize;
409                    let mut k_bytes = vec![0; k_len];
410                    cursor.copy_to_slice(&mut k_bytes);
411                    let v_len = cursor.get_u16() as usize;
412                    let mut v_bytes = vec![0; v_len];
413                    cursor.copy_to_slice(&mut v_bytes);
414                    properties.push(Property::UserProperty(
415                        String::from_utf8_lossy(&k_bytes).to_string(),
416                        String::from_utf8_lossy(&v_bytes).to_string(),
417                    ));
418                }
419                _ => return Err(MqttError::MalformedPacket("Unknown property identifier")),
420            }
421        } else {
422            break;
423        }
424    }
425
426    Ok(properties)
427}