mqtt_protocol_core/mqtt/connection/packet_builder.rs

// MIT License
//
// Copyright (c) 2025 Takatoshi Kondo
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
22use crate::mqtt::common::Cursor;
23use crate::mqtt::result_code::MqttError;
24use alloc::{sync::Arc, vec::Vec};
25
/// Body bytes of a received packet (everything after the Remaining Length).
///
/// `PacketBuilder` stores PUBLISH bodies in an `Arc<[u8]>`, so cloning one
/// shares the allocation; all other packet bodies are owned `Vec<u8>`s.
#[derive(Clone, Debug)]
pub enum PacketData {
    /// Owned body of a non-PUBLISH packet.
    Normal(Vec<u8>),
    /// Reference-counted body of a PUBLISH packet.
    Publish(Arc<[u8]>),
}

impl PacketData {
    /// Borrows the stored bytes, whichever variant holds them.
    pub fn as_slice(&self) -> &[u8] {
        match self {
            Self::Normal(bytes) => &bytes[..],
            Self::Publish(shared) => &shared[..],
        }
    }

    /// Returns the number of stored bytes.
    ///
    /// # Panics
    ///
    /// Panics if the length does not fit in a `u32`.
    pub fn len(&self) -> u32 {
        let byte_count = self.as_slice().len();
        u32::try_from(byte_count).unwrap()
    }

    /// Returns `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
51
52#[derive(Debug, Clone)]
53pub struct RawPacket {
54    fixed_header: u8,
55    pub data: PacketData,
56}
57
58impl RawPacket {
59    pub fn data_as_slice(&self) -> &[u8] {
60        self.data.as_slice()
61    }
62
63    pub fn packet_type(&self) -> u8 {
64        self.fixed_header >> 4
65    }
66
67    pub fn flags(&self) -> u8 {
68        self.fixed_header & 0x0F
69    }
70
71    pub fn is_publish(&self) -> bool {
72        self.packet_type() == 3
73    }
74
75    pub fn remaining_length(&self) -> u32 {
76        self.data.len()
77    }
78}
79
80/// Enum representing packet construction results
81#[derive(Debug)]
82pub enum PacketBuildResult {
83    /// Packet construction completed
84    Complete(RawPacket),
85    /// Packet building in progress (more data needed)
86    Incomplete,
87    /// Error occurred
88    Error(MqttError),
89}
90
/// Incremental assembler that turns a byte stream into MQTT packets.
///
/// Bytes are supplied through `feed`; partial progress is kept between calls.
pub struct PacketBuilder {
    /// Which part of the packet will be read next.
    state: ReadState,
    /// Fixed-header byte followed by the Remaining Length bytes read so far.
    header_buf: Vec<u8>,
    /// Remaining Length decoded so far; while the body is being read it
    /// counts the body bytes still expected.
    remaining_length: usize,
    /// Weight applied to the next Remaining Length byte (1, 128, 128^2, ...).
    multiplier: u32,
    /// Body bytes (everything after the Remaining Length); `Some` only while
    /// the body is being read.
    raw_buf: Option<Vec<u8>>,
    /// Number of body bytes already written into `raw_buf`.
    raw_buf_offset: usize,
}

/// Parsing stage a `PacketBuilder` is in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ReadState {
    /// Waiting for the first fixed-header byte.
    FixedHeader,
    /// Decoding the variable-length Remaining Length field.
    RemainingLength,
    /// Copying the packet body.
    Payload,
}
117
118impl PacketBuilder {
119    /// Create new packet builder
120    pub fn new() -> Self {
121        Self {
122            state: ReadState::FixedHeader,
123            header_buf: Vec::with_capacity(5),
124            remaining_length: 0,
125            multiplier: 1,
126            raw_buf: None,
127            raw_buf_offset: 0,
128        }
129    }
130
131    /// Reset builder for reuse
132    pub fn reset(&mut self) {
133        self.state = ReadState::FixedHeader;
134        self.header_buf.clear();
135        self.remaining_length = 0;
136        self.multiplier = 1;
137        self.raw_buf = None;
138        self.raw_buf_offset = 0;
139    }
140
141    /// Get packet type (first byte of fixed header)
142    fn get_packet_type(&self) -> u8 {
143        if !self.header_buf.is_empty() {
144            self.header_buf[0]
145        } else {
146            0
147        }
148    }
149
150    /// Determine if packet is PUBLISH
151    fn is_publish_packet(&self) -> bool {
152        (self.get_packet_type() & 0xF0) == 0x30
153    }
154
155    /// Build packet from data stream
156    pub fn feed(&mut self, data: &mut Cursor<&[u8]>) -> PacketBuildResult {
157        let available = data.get_ref().len() as u64 - data.position();
158        if available == 0 {
159            return PacketBuildResult::Incomplete;
160        }
161
162        let mut byte = [0u8; 1];
163
164        loop {
165            match self.state {
166                ReadState::FixedHeader => {
167                    if data.read_exact(&mut byte).is_err() {
168                        return PacketBuildResult::Incomplete;
169                    }
170
171                    self.header_buf.push(byte[0]);
172                    self.state = ReadState::RemainingLength;
173                }
174
175                ReadState::RemainingLength => {
176                    if data.read_exact(&mut byte).is_err() {
177                        return PacketBuildResult::Incomplete;
178                    }
179
180                    self.header_buf.push(byte[0]);
181                    let encoded_byte = byte[0];
182
183                    self.remaining_length +=
184                        ((encoded_byte & 0x7F) as usize) * (self.multiplier as usize);
185                    self.multiplier *= 128;
186
187                    // Variable-length integer limit check
188                    if self.multiplier > 128 * 128 * 128 {
189                        self.reset();
190                        return PacketBuildResult::Error(MqttError::MalformedPacket);
191                    }
192
193                    if (encoded_byte & 0x80) == 0 {
194                        if self.remaining_length == 0 {
195                            let fixed_header = self.header_buf[0];
196                            let packet_data = if self.is_publish_packet() {
197                                // Use Arc for PUBLISH packets
198                                PacketData::Publish(Arc::from([]))
199                            } else {
200                                // Use Vec for other packets
201                                PacketData::Normal(Vec::new())
202                            };
203
204                            let packet = RawPacket {
205                                fixed_header,
206                                data: packet_data,
207                            };
208                            self.reset();
209                            return PacketBuildResult::Complete(packet);
210                        } else {
211                            self.raw_buf = Some(Vec::with_capacity(self.remaining_length));
212                            self.raw_buf_offset = 0;
213                            self.state = ReadState::Payload;
214                        }
215                    }
216                }
217
218                ReadState::Payload => {
219                    let raw_buf = self.raw_buf.as_mut().unwrap();
220                    let bytes_remaining = self.remaining_length;
221
222                    let position = data.position();
223                    let available = data.get_ref().len() as u64 - position;
224                    let bytes_to_read = bytes_remaining.min(available as usize);
225
226                    if bytes_to_read == 0 {
227                        return PacketBuildResult::Incomplete;
228                    }
229
230                    raw_buf.resize(self.raw_buf_offset + bytes_to_read, 0);
231
232                    let read_slice =
233                        &mut raw_buf[self.raw_buf_offset..self.raw_buf_offset + bytes_to_read];
234                    let bytes_read = data.read(read_slice).unwrap();
235
236                    self.raw_buf_offset += bytes_read;
237                    self.remaining_length -= bytes_read;
238
239                    if self.remaining_length == 0 {
240                        let raw_buf = self.raw_buf.take().unwrap();
241                        let fixed_header = self.header_buf[0];
242
243                        let packet_data = if self.is_publish_packet() {
244                            // Use Arc for PUBLISH packets
245                            PacketData::Publish(Arc::from(raw_buf.into_boxed_slice()))
246                        } else {
247                            // Use Vec for other packets
248                            PacketData::Normal(raw_buf)
249                        };
250
251                        let packet = RawPacket {
252                            fixed_header,
253                            data: packet_data,
254                        };
255                        self.reset();
256                        return PacketBuildResult::Complete(packet);
257                    }
258                    return PacketBuildResult::Incomplete;
259                }
260            }
261        }
262    }
263}