mqtt_protocol_core/mqtt/connection/
packet_builder.rs

1// MIT License
2//
3// Copyright (c) 2025 Takatoshi Kondo
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in all
13// copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23use crate::mqtt::common::{Arc, Cursor};
24use crate::mqtt::result_code::MqttError;
25use alloc::vec::Vec;
26
/// Owned bytes of a packet body.
///
/// `PacketBuilder` stores PUBLISH bodies as a reference-counted slice (so
/// clones share one allocation) and every other packet as a plain `Vec<u8>`.
#[derive(Debug, Clone)]
pub enum PacketData {
    /// Body held in an owned vector.
    Normal(Vec<u8>),
    /// Body held in a shared, reference-counted slice.
    Publish(Arc<[u8]>),
}

impl PacketData {
    /// Borrow the stored bytes, whichever variant holds them.
    pub fn as_slice(&self) -> &[u8] {
        match self {
            Self::Normal(bytes) => &bytes[..],
            Self::Publish(shared) => &shared[..],
        }
    }

    /// Number of stored bytes.
    ///
    /// # Panics
    ///
    /// Panics if the byte count does not fit in a `u32`.
    pub fn len(&self) -> u32 {
        let byte_count = self.as_slice().len();
        byte_count.try_into().unwrap()
    }

    /// `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
52
53#[derive(Debug, Clone)]
54pub struct RawPacket {
55    fixed_header: u8,
56    pub data: PacketData,
57}
58
59impl RawPacket {
60    pub fn data_as_slice(&self) -> &[u8] {
61        self.data.as_slice()
62    }
63
64    pub fn packet_type(&self) -> u8 {
65        self.fixed_header >> 4
66    }
67
68    pub fn flags(&self) -> u8 {
69        self.fixed_header & 0x0F
70    }
71
72    pub fn is_publish(&self) -> bool {
73        self.packet_type() == 3
74    }
75
76    pub fn remaining_length(&self) -> u32 {
77        self.data.len()
78    }
79}
80
/// Enum representing packet construction results.
///
/// This is the value returned by `PacketBuilder::feed` for each chunk of input.
#[derive(Debug)]
pub enum PacketBuildResult {
    /// Packet construction completed; carries the assembled packet.
    Complete(RawPacket),
    /// Packet building in progress (more data needed). `PacketBuilder::feed`
    /// keeps its partial state so the next call continues where this one stopped.
    Incomplete,
    /// Error occurred (e.g. a malformed Remaining Length field).
    /// `PacketBuilder::feed` resets the builder before returning this.
    Error(MqttError),
}
91
/// Builder for constructing MQTT packet byte sequences.
///
/// Assembles one packet at a time from byte chunks passed to `feed`, keeping
/// partial progress between calls.
pub struct PacketBuilder {
    /// Current read state
    state: ReadState,
    /// Buffer for the fixed-header byte and the Remaining Length bytes read so far
    header_buf: Vec<u8>,
    /// Remaining length: accumulated while decoding the length field, then
    /// counted down as payload bytes are read
    remaining_length: usize,
    /// Multiplier for variable-length integer decoding; starts at 1 and is
    /// multiplied by 128 as successive length bytes are consumed
    multiplier: u32,
    /// Buffer for the packet body; `Some` only while reading the payload
    raw_buf: Option<Vec<u8>>,
    /// Number of payload bytes already stored in `raw_buf`
    raw_buf_offset: usize,
}
107
/// Packet reading state of `PacketBuilder`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadState {
    /// Reading fixed header (its first byte: packet type and flags)
    FixedHeader,
    /// Reading remaining length (variable-length integer, one byte at a time)
    RemainingLength,
    /// Reading payload (the `remaining_length` bytes that follow the header)
    Payload,
}
118
119impl PacketBuilder {
120    /// Create new packet builder
121    pub fn new() -> Self {
122        Self {
123            state: ReadState::FixedHeader,
124            header_buf: Vec::with_capacity(5),
125            remaining_length: 0,
126            multiplier: 1,
127            raw_buf: None,
128            raw_buf_offset: 0,
129        }
130    }
131
132    /// Reset builder for reuse
133    pub fn reset(&mut self) {
134        self.state = ReadState::FixedHeader;
135        self.header_buf.clear();
136        self.remaining_length = 0;
137        self.multiplier = 1;
138        self.raw_buf = None;
139        self.raw_buf_offset = 0;
140    }
141
142    /// Get packet type (first byte of fixed header)
143    fn get_packet_type(&self) -> u8 {
144        if !self.header_buf.is_empty() {
145            self.header_buf[0]
146        } else {
147            0
148        }
149    }
150
151    /// Determine if packet is PUBLISH
152    fn is_publish_packet(&self) -> bool {
153        (self.get_packet_type() & 0xF0) == 0x30
154    }
155
156    /// Build packet from data stream
157    pub fn feed(&mut self, data: &mut Cursor<&[u8]>) -> PacketBuildResult {
158        let available = data.get_ref().len() as u64 - data.position();
159        if available == 0 {
160            return PacketBuildResult::Incomplete;
161        }
162
163        let mut byte = [0u8; 1];
164
165        loop {
166            match self.state {
167                ReadState::FixedHeader => {
168                    if data.read_exact(&mut byte).is_err() {
169                        return PacketBuildResult::Incomplete;
170                    }
171
172                    self.header_buf.push(byte[0]);
173                    self.state = ReadState::RemainingLength;
174                }
175
176                ReadState::RemainingLength => {
177                    if data.read_exact(&mut byte).is_err() {
178                        return PacketBuildResult::Incomplete;
179                    }
180
181                    self.header_buf.push(byte[0]);
182                    let encoded_byte = byte[0];
183
184                    self.remaining_length +=
185                        ((encoded_byte & 0x7F) as usize) * (self.multiplier as usize);
186                    self.multiplier *= 128;
187
188                    // Variable-length integer limit check
189                    if self.multiplier > 128 * 128 * 128 {
190                        self.reset();
191                        return PacketBuildResult::Error(MqttError::MalformedPacket);
192                    }
193
194                    if (encoded_byte & 0x80) == 0 {
195                        if self.remaining_length == 0 {
196                            let fixed_header = self.header_buf[0];
197                            let packet_data = if self.is_publish_packet() {
198                                // Use Arc for PUBLISH packets
199                                PacketData::Publish(Arc::from([]))
200                            } else {
201                                // Use Vec for other packets
202                                PacketData::Normal(Vec::new())
203                            };
204
205                            let packet = RawPacket {
206                                fixed_header,
207                                data: packet_data,
208                            };
209                            self.reset();
210                            return PacketBuildResult::Complete(packet);
211                        } else {
212                            self.raw_buf = Some(Vec::with_capacity(self.remaining_length));
213                            self.raw_buf_offset = 0;
214                            self.state = ReadState::Payload;
215                        }
216                    }
217                }
218
219                ReadState::Payload => {
220                    let raw_buf = self.raw_buf.as_mut().unwrap();
221                    let bytes_remaining = self.remaining_length;
222
223                    let position = data.position();
224                    let available = data.get_ref().len() as u64 - position;
225                    let bytes_to_read = bytes_remaining.min(available as usize);
226
227                    if bytes_to_read == 0 {
228                        return PacketBuildResult::Incomplete;
229                    }
230
231                    raw_buf.resize(self.raw_buf_offset + bytes_to_read, 0);
232
233                    let read_slice =
234                        &mut raw_buf[self.raw_buf_offset..self.raw_buf_offset + bytes_to_read];
235                    let bytes_read = data.read(read_slice).unwrap();
236
237                    self.raw_buf_offset += bytes_read;
238                    self.remaining_length -= bytes_read;
239
240                    if self.remaining_length == 0 {
241                        let raw_buf = self.raw_buf.take().unwrap();
242                        let fixed_header = self.header_buf[0];
243
244                        let packet_data = if self.is_publish_packet() {
245                            // Use Arc for PUBLISH packets
246                            PacketData::Publish(Arc::from(raw_buf.into_boxed_slice()))
247                        } else {
248                            // Use Vec for other packets
249                            PacketData::Normal(raw_buf)
250                        };
251
252                        let packet = RawPacket {
253                            fixed_header,
254                            data: packet_data,
255                        };
256                        self.reset();
257                        return PacketBuildResult::Complete(packet);
258                    }
259                    return PacketBuildResult::Incomplete;
260                }
261            }
262        }
263    }
264}