mqtt-frame 0.1.5

A lightweight, Sans-I/O MQTT v3.1.1 and v5.0 protocol codec and parser for Danube.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
use bytes::{Buf, BufMut, BytesMut};
use std::io::Cursor;
use tokio_util::codec::{Decoder, Encoder};

use crate::error::MqttError;
use crate::packet::{
    Connect, MqttPacket, Property, ProtocolLevel, PubAck, PubComp, PubRec, PubRel, Publish, SubAck,
    Subscribe, UnsubAck, Unsubscribe,
};
use crate::utils::read_var_int;

/// Stateful MQTT packet codec for `tokio_util::codec` framing.
///
/// Decoding a CONNECT packet overwrites `protocol_level`, and that level then
/// controls how subsequent packets are decoded and encoded. Because of this
/// state, an instance should not be shared between unrelated connections.
pub struct MqttCodec {
    /// Protocol level used to interpret packets. Starts as
    /// `ProtocolLevel::V311` and is replaced by the level of each decoded
    /// CONNECT packet.
    pub protocol_level: ProtocolLevel,
    /// Upper bound, in bytes, on a packet's remaining length.
    ///
    /// `None` (the default) disables the check. When set, `decode()` rejects
    /// a larger packet with `MqttError::PayloadTooLarge` **before** reserving
    /// buffer space for the packet body.
    pub max_packet_size: Option<usize>,
}

impl MqttCodec {
    /// Builds a codec in MQTT v3.1.1 mode with no packet size limit.
    pub fn new() -> Self {
        Self {
            max_packet_size: None,
            protocol_level: ProtocolLevel::V311,
        }
    }

    /// Builds a codec in MQTT v3.1.1 mode that refuses any packet whose
    /// remaining length is greater than `max_size` bytes.
    ///
    /// The limit is enforced in `decode()` before buffer space is reserved
    /// for the packet body.
    pub fn with_max_packet_size(max_size: usize) -> Self {
        Self {
            max_packet_size: Some(max_size),
            ..Self::new()
        }
    }
}

impl Default for MqttCodec {
    /// Same as [`MqttCodec::new`]: v3.1.1 mode, unlimited packet size.
    fn default() -> Self {
        MqttCodec::new()
    }
}

impl Decoder for MqttCodec {
    type Item = MqttPacket;
    type Error = MqttError;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if src.is_empty() {
            return Ok(None);
        }

        let mut cursor = Cursor::new(&src[..]);
        let fixed_header = cursor.get_u8();
        let packet_type = fixed_header >> 4;
        let flags = fixed_header & 0x0F;

        let var_int_result = read_var_int(&mut cursor)?;
        let remaining_length = match var_int_result {
            Some((len, _)) => len as usize,
            None => return Ok(None), // Not enough data for length
        };

        let header_len = cursor.position() as usize;
        let total_len = header_len + remaining_length;

        // Enforce max packet size BEFORE allocating memory.
        // This prevents a single oversized PUBLISH from exhausting server RAM.
        if let Some(max) = self.max_packet_size {
            if remaining_length > max {
                // If the oversized packet is already fully buffered, drain it
                // to prevent the decoder from re-reading the same bytes forever.
                if src.len() >= total_len {
                    src.advance(total_len);
                }
                return Err(MqttError::PayloadTooLarge {
                    size: remaining_length,
                    limit: max,
                });
            }
        }

        if src.len() < total_len {
            src.reserve(total_len - src.len());
            return Ok(None); // Wait for more data
        }

        // We have the full packet. Let's slice it out using zero-copy.
        let packet_bytes = src.split_to(total_len).freeze();

        let mut payload_cursor = Cursor::new(&packet_bytes[header_len..]);

        let packet = match packet_type {
            1 => {
                // CONNECT
                let protocol_name_len = payload_cursor.get_u16() as usize;
                let mut protocol_name = vec![0; protocol_name_len];
                payload_cursor.copy_to_slice(&mut protocol_name);

                let protocol_level_byte = payload_cursor.get_u8();
                let protocol_level = match protocol_level_byte {
                    4 => ProtocolLevel::V311,
                    5 => ProtocolLevel::V5,
                    _ => return Err(MqttError::UnsupportedVersion),
                };

                let connect_flags = payload_cursor.get_u8();
                let clean_session = (connect_flags & 0x02) != 0;
                let keep_alive = payload_cursor.get_u16();

                // Properties (if v5)
                if protocol_level == ProtocolLevel::V5 {
                    if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
                        payload_cursor.advance(props_len as usize); // Skip properties for now
                    } else {
                        return Err(MqttError::MalformedPacket("Incomplete v5 properties"));
                    }
                }

                // Client ID
                let client_id_len = payload_cursor.get_u16() as usize;
                let mut client_id_bytes = vec![0; client_id_len];
                payload_cursor.copy_to_slice(&mut client_id_bytes);
                let client_id = String::from_utf8_lossy(&client_id_bytes).to_string();

                // Self-update the codec's protocol level for subsequent packets!
                self.protocol_level = protocol_level;

                MqttPacket::Connect(Connect {
                    protocol_level,
                    client_id,
                    clean_session,
                    keep_alive,
                })
            }
            3 => {
                // PUBLISH
                let dup = (flags & 0x08) != 0;
                let qos = (flags & 0x06) >> 1;
                let retain = (flags & 0x01) != 0;

                let topic_len = payload_cursor.get_u16() as usize;
                let mut topic_bytes = vec![0; topic_len];
                payload_cursor.copy_to_slice(&mut topic_bytes);
                let topic = String::from_utf8_lossy(&topic_bytes).to_string();

                let packet_id = if qos > 0 {
                    Some(payload_cursor.get_u16())
                } else {
                    None
                };

                let mut properties = Vec::new();

                // If V5, parse properties before extracting payload
                if self.protocol_level == ProtocolLevel::V5 {
                    if let Some((props_len, _)) = read_var_int(&mut payload_cursor)? {
                        let props_end = payload_cursor.position() as usize + props_len as usize;
                        if total_len < header_len + props_end {
                            return Err(MqttError::MalformedPacket(
                                "Properties length exceeds packet",
                            ));
                        }
                        properties = parse_properties(&mut payload_cursor, props_len as usize)?;
                    } else {
                        return Err(MqttError::MalformedPacket(
                            "Incomplete v5 properties in PUBLISH",
                        ));
                    }
                }

                // Payload is the rest of the packet
                let payload_start = header_len + payload_cursor.position() as usize;
                let payload = packet_bytes.slice(payload_start..total_len);

                MqttPacket::Publish(Publish {
                    dup,
                    qos,
                    retain,
                    topic,
                    packet_id,
                    properties,
                    payload,
                })
            }
            4 => {
                let packet_id = payload_cursor.get_u16();
                // MQTT v5: reason code follows packet_id if remaining_length > 2
                let reason_code =
                    if self.protocol_level == ProtocolLevel::V5 && remaining_length > 2 {
                        Some(payload_cursor.get_u8())
                    } else {
                        None
                    };
                MqttPacket::PubAck(PubAck {
                    packet_id,
                    reason_code,
                })
            }
            5 => MqttPacket::PubRec(PubRec {
                packet_id: payload_cursor.get_u16(),
            }),
            6 => MqttPacket::PubRel(PubRel {
                packet_id: payload_cursor.get_u16(),
            }),
            7 => MqttPacket::PubComp(PubComp {
                packet_id: payload_cursor.get_u16(),
            }),
            8 => {
                // SUBSCRIBE
                let packet_id = payload_cursor.get_u16();
                let mut filters = Vec::new();
                while payload_cursor.has_remaining() {
                    let topic_len = payload_cursor.get_u16() as usize;
                    let mut topic_bytes = vec![0; topic_len];
                    payload_cursor.copy_to_slice(&mut topic_bytes);
                    let topic = String::from_utf8_lossy(&topic_bytes).to_string();
                    let qos = payload_cursor.get_u8();
                    filters.push((topic, qos));
                }
                MqttPacket::Subscribe(Subscribe { packet_id, filters })
            }
            9 => {
                // SUBACK
                let packet_id = payload_cursor.get_u16();
                let mut return_codes = Vec::new();
                while payload_cursor.has_remaining() {
                    return_codes.push(payload_cursor.get_u8());
                }
                MqttPacket::SubAck(SubAck {
                    packet_id,
                    return_codes,
                })
            }
            10 => {
                // UNSUBSCRIBE
                let packet_id = payload_cursor.get_u16();
                let mut filters = Vec::new();
                while payload_cursor.has_remaining() {
                    let topic_len = payload_cursor.get_u16() as usize;
                    let mut topic_bytes = vec![0; topic_len];
                    payload_cursor.copy_to_slice(&mut topic_bytes);
                    filters.push(String::from_utf8_lossy(&topic_bytes).to_string());
                }
                MqttPacket::Unsubscribe(Unsubscribe { packet_id, filters })
            }
            11 => MqttPacket::UnsubAck(UnsubAck {
                packet_id: payload_cursor.get_u16(),
            }),
            12 => MqttPacket::PingReq,
            13 => MqttPacket::PingResp,
            14 => MqttPacket::Disconnect,
            _ => {
                return Err(MqttError::ProtocolError(format!(
                    "Unsupported packet type: {}",
                    packet_type
                )))
            }
        };

        Ok(Some(packet))
    }
}

impl Encoder<MqttPacket> for MqttCodec {
    type Error = MqttError;

    /// Serialises `item` into `dst`.
    ///
    /// The wire format follows `self.protocol_level`, which `decode()` updates
    /// when it reads a CONNECT. Packets sent on a v5 connection therefore get
    /// the v5 layout where it differs from v3.1.1.
    ///
    /// # Errors
    ///
    /// Returns `MqttError::ProtocolError` for packet types that have no encoder
    /// here (e.g. CONNECT, PUBLISH, SUBSCRIBE, UNSUBSCRIBE). Also propagates
    /// any error from `write_var_int` while encoding SUBACK.
    fn encode(&mut self, item: MqttPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let is_v5 = self.protocol_level == ProtocolLevel::V5;

        match item {
            MqttPacket::ConnAck(connack) => {
                dst.put_u8(0x20); // Type 2 (CONNACK)
                if is_v5 {
                    // MQTT v5 §3.2.2.3: the Property Length is mandatory in
                    // CONNACK, so the minimal form is ack flags + reason code
                    // + an empty property block (remaining length 3). Omitting
                    // it makes v5 clients reject the packet as malformed.
                    dst.put_u8(3);
                    dst.put_u8(u8::from(connack.session_present));
                    dst.put_u8(connack.return_code);
                    dst.put_u8(0); // 0 Properties
                } else {
                    dst.put_u8(2); // Remaining length is always 2 for v3.1.1
                    dst.put_u8(u8::from(connack.session_present));
                    dst.put_u8(connack.return_code);
                }
            }
            MqttPacket::PingResp => {
                dst.put_u8(0xD0); // Type 13 (PINGRESP)
                dst.put_u8(0); // Remaining length is 0
            }
            MqttPacket::PubAck(puback) => {
                dst.put_u8(0x40);
                let reason = puback.reason_code.unwrap_or(0x00);
                if is_v5 && reason != 0x00 {
                    // Non-success: must include reason code + empty properties
                    dst.put_u8(4); // Remaining length: 2 (ID) + 1 (reason) + 1 (props=0)
                    dst.put_u16(puback.packet_id);
                    dst.put_u8(reason);
                    dst.put_u8(0); // 0 Properties
                } else {
                    // v3.1.1 has no reason code. v5 §3.4.2.1 lets a success
                    // PUBACK with no properties omit both Reason Code and
                    // Property Length (short form).
                    dst.put_u8(2); // Remaining length: 2 (ID)
                    dst.put_u16(puback.packet_id);
                }
            }
            // PUBREC / PUBREL / PUBCOMP carry no reason code here, so the
            // 2-byte short form is used for both protocol levels.
            MqttPacket::PubRec(pubrec) => {
                dst.put_u8(0x50);
                dst.put_u8(2);
                dst.put_u16(pubrec.packet_id);
            }
            MqttPacket::PubRel(pubrel) => {
                dst.put_u8(0x62); // PUBREL fixed-header flags must be 0b0010
                dst.put_u8(2);
                dst.put_u16(pubrel.packet_id);
            }
            MqttPacket::PubComp(pubcomp) => {
                dst.put_u8(0x70);
                dst.put_u8(2);
                dst.put_u16(pubcomp.packet_id);
            }
            MqttPacket::SubAck(suback) => {
                dst.put_u8(0x90);
                // Length is 2 (packet id) + number of return codes
                // + 1 byte of (empty) property length under v5.
                let props_len: u32 = if is_v5 { 1 } else { 0 };
                let remaining_len = 2 + suback.return_codes.len() as u32 + props_len;
                crate::utils::write_var_int(remaining_len, dst)?;
                dst.put_u16(suback.packet_id);

                if is_v5 {
                    dst.put_u8(0); // 0 Properties
                }

                dst.put_slice(&suback.return_codes);
            }
            MqttPacket::UnsubAck(unsuback) => {
                // NOTE(review): a v5 UNSUBACK also requires a property block
                // and one reason code per topic filter. `UnsubAck` only carries
                // `packet_id`, so the v3.1.1 form is emitted for both levels.
                dst.put_u8(0xB0);
                dst.put_u8(2);
                dst.put_u16(unsuback.packet_id);
            }
            MqttPacket::PingReq => {
                dst.put_u8(0xC0);
                dst.put_u8(0);
            }
            MqttPacket::Disconnect => {
                // A zero-length DISCONNECT is valid for both v3.1.1 and v5
                // (v5 treats it as reason code 0x00, no properties).
                dst.put_u8(0xE0);
                dst.put_u8(0);
            }
            _ => {
                return Err(MqttError::ProtocolError(
                    "Packet encoding not implemented for this type".into(),
                ))
            }
        }
        Ok(())
    }
}

pub fn parse_properties(
    cursor: &mut Cursor<&[u8]>,
    length: usize,
) -> Result<Vec<Property>, MqttError> {
    let mut properties = Vec::new();
    let start_pos = cursor.position() as usize;

    while (cursor.position() as usize - start_pos) < length {
        if let Some((identifier, _)) = read_var_int(cursor)? {
            match identifier {
                0x01 => properties.push(Property::PayloadFormatIndicator(cursor.get_u8())),
                0x02 => properties.push(Property::MessageExpiryInterval(cursor.get_u32())),
                0x03 => {
                    let str_len = cursor.get_u16() as usize;
                    let mut str_bytes = vec![0; str_len];
                    cursor.copy_to_slice(&mut str_bytes);
                    properties.push(Property::ContentType(
                        String::from_utf8_lossy(&str_bytes).to_string(),
                    ));
                }
                0x08 => {
                    let str_len = cursor.get_u16() as usize;
                    let mut str_bytes = vec![0; str_len];
                    cursor.copy_to_slice(&mut str_bytes);
                    properties.push(Property::ResponseTopic(
                        String::from_utf8_lossy(&str_bytes).to_string(),
                    ));
                }
                0x09 => {
                    let bin_len = cursor.get_u16() as usize;
                    let mut bin_bytes = vec![0; bin_len];
                    cursor.copy_to_slice(&mut bin_bytes);
                    properties.push(Property::CorrelationData(bin_bytes));
                }
                0x0B => {
                    if let Some((sub_id, _)) = read_var_int(cursor)? {
                        properties.push(Property::SubscriptionIdentifier(sub_id));
                    }
                }
                0x23 => properties.push(Property::TopicAlias(cursor.get_u16())),
                0x26 => {
                    let k_len = cursor.get_u16() as usize;
                    let mut k_bytes = vec![0; k_len];
                    cursor.copy_to_slice(&mut k_bytes);
                    let v_len = cursor.get_u16() as usize;
                    let mut v_bytes = vec![0; v_len];
                    cursor.copy_to_slice(&mut v_bytes);
                    properties.push(Property::UserProperty(
                        String::from_utf8_lossy(&k_bytes).to_string(),
                        String::from_utf8_lossy(&v_bytes).to_string(),
                    ));
                }
                _ => return Err(MqttError::MalformedPacket("Unknown property identifier")),
            }
        } else {
            break;
        }
    }

    Ok(properties)
}