mqtt_protocol_core/mqtt/connection/packet_builder.rs

1/**
2 * MIT License
3 *
4 * Copyright (c) 2025 Takatoshi Kondo
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to deal
8 * in the Software without restriction, including without limitation the rights
9 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 * copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24use crate::mqtt::result_code::MqttError;
25use std::io::{Cursor, Read};
26use std::sync::Arc;
27
/// Body bytes of an MQTT packet (the part counted by the remaining length).
///
/// PUBLISH bodies are held in an [`Arc<[u8]>`](Arc), so cloning a
/// `PacketData::Publish` only bumps a reference count. Every other packet type
/// owns its bytes in a plain [`Vec`].
#[derive(Debug, Clone)]
pub enum PacketData {
    /// Body of a non-PUBLISH packet, exclusively owned.
    Normal(Vec<u8>),
    /// Body of a PUBLISH packet, in a reference-counted buffer.
    Publish(Arc<[u8]>),
}

impl PacketData {
    /// Borrow the body bytes, whichever storage variant holds them.
    pub fn as_slice(&self) -> &[u8] {
        match self {
            PacketData::Normal(vec) => vec.as_slice(),
            PacketData::Publish(arc) => arc.as_ref(),
        }
    }

    /// Number of body bytes, as a `u32`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer holds more than `u32::MAX` bytes. Bodies produced by
    /// [`PacketBuilder`] never do, because an MQTT remaining length is encoded in
    /// at most four bytes (maximum 268,435,455).
    pub fn len(&self) -> u32 {
        u32::try_from(self.as_slice().len()).expect("packet data length exceeds u32::MAX")
    }

    /// `true` when the body holds no bytes.
    ///
    /// Unlike [`len`](Self::len), this never panics, whatever the buffer size.
    pub fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }
}
53
54#[derive(Debug, Clone)]
55pub struct RawPacket {
56    fixed_header: u8,
57    pub data: PacketData,
58}
59
60impl RawPacket {
61    pub fn data_as_slice(&self) -> &[u8] {
62        self.data.as_slice()
63    }
64
65    pub fn packet_type(&self) -> u8 {
66        self.fixed_header >> 4
67    }
68
69    pub fn flags(&self) -> u8 {
70        self.fixed_header & 0x0F
71    }
72
73    pub fn is_publish(&self) -> bool {
74        self.packet_type() == 3
75    }
76
77    pub fn remaining_length(&self) -> u32 {
78        self.data.len()
79    }
80}
81
82/// Enum representing packet construction results
83#[derive(Debug)]
84pub enum PacketBuildResult {
85    /// Packet construction completed
86    Complete(RawPacket),
87    /// Packet building in progress (more data needed)
88    Incomplete,
89    /// Error occurred
90    Error(MqttError),
91}
92
93/// Builder for constructing MQTT packet byte sequences
94pub struct PacketBuilder {
95    /// Current read state
96    state: ReadState,
97    /// Buffer for header and remaining length
98    header_buf: Vec<u8>,
99    /// Remaining length
100    remaining_length: usize,
101    /// Multiplier for variable-length integer decoding
102    multiplier: u32,
103    /// Buffer for entire packet
104    raw_buf: Option<Vec<u8>>,
105    /// Current position in buffer
106    raw_buf_offset: usize,
107}
108
109/// Packet reading state
110#[derive(Debug, Clone, Copy, PartialEq, Eq)]
111enum ReadState {
112    /// Reading fixed header
113    FixedHeader,
114    /// Reading remaining length
115    RemainingLength,
116    /// Reading payload
117    Payload,
118}
119
120impl PacketBuilder {
121    /// Create new packet builder
122    pub fn new() -> Self {
123        Self {
124            state: ReadState::FixedHeader,
125            header_buf: Vec::with_capacity(5),
126            remaining_length: 0,
127            multiplier: 1,
128            raw_buf: None,
129            raw_buf_offset: 0,
130        }
131    }
132
133    /// Reset builder for reuse
134    pub fn reset(&mut self) {
135        self.state = ReadState::FixedHeader;
136        self.header_buf.clear();
137        self.remaining_length = 0;
138        self.multiplier = 1;
139        self.raw_buf = None;
140        self.raw_buf_offset = 0;
141    }
142
143    /// Get packet type (first byte of fixed header)
144    fn get_packet_type(&self) -> u8 {
145        if !self.header_buf.is_empty() {
146            self.header_buf[0]
147        } else {
148            0
149        }
150    }
151
152    /// Determine if packet is PUBLISH
153    fn is_publish_packet(&self) -> bool {
154        (self.get_packet_type() & 0xF0) == 0x30
155    }
156
157    /// Build packet from data stream
158    pub fn feed(&mut self, data: &mut Cursor<&[u8]>) -> PacketBuildResult {
159        let available = data.get_ref().len() as u64 - data.position();
160        if available == 0 {
161            return PacketBuildResult::Incomplete;
162        }
163
164        let mut byte = [0u8; 1];
165
166        loop {
167            match self.state {
168                ReadState::FixedHeader => {
169                    if data.read_exact(&mut byte).is_err() {
170                        return PacketBuildResult::Incomplete;
171                    }
172
173                    self.header_buf.push(byte[0]);
174                    self.state = ReadState::RemainingLength;
175                }
176
177                ReadState::RemainingLength => {
178                    if data.read_exact(&mut byte).is_err() {
179                        return PacketBuildResult::Incomplete;
180                    }
181
182                    self.header_buf.push(byte[0]);
183                    let encoded_byte = byte[0];
184
185                    self.remaining_length +=
186                        ((encoded_byte & 0x7F) as usize) * (self.multiplier as usize);
187                    self.multiplier *= 128;
188
189                    // Variable-length integer limit check
190                    if self.multiplier > 128 * 128 * 128 {
191                        self.reset();
192                        return PacketBuildResult::Error(MqttError::MalformedPacket);
193                    }
194
195                    if (encoded_byte & 0x80) == 0 {
196                        if self.remaining_length == 0 {
197                            let fixed_header = self.header_buf[0];
198                            let packet_data = if self.is_publish_packet() {
199                                // Use Arc for PUBLISH packets
200                                PacketData::Publish(Arc::from([]))
201                            } else {
202                                // Use Vec for other packets
203                                PacketData::Normal(Vec::new())
204                            };
205
206                            let packet = RawPacket {
207                                fixed_header,
208                                data: packet_data,
209                            };
210                            self.reset();
211                            return PacketBuildResult::Complete(packet);
212                        } else {
213                            self.raw_buf = Some(Vec::with_capacity(self.remaining_length));
214                            self.raw_buf_offset = 0;
215                            self.state = ReadState::Payload;
216                        }
217                    }
218                }
219
220                ReadState::Payload => {
221                    let raw_buf = self.raw_buf.as_mut().unwrap();
222                    let bytes_remaining = self.remaining_length;
223
224                    let position = data.position();
225                    let available = data.get_ref().len() as u64 - position;
226                    let bytes_to_read = bytes_remaining.min(available as usize);
227
228                    if bytes_to_read == 0 {
229                        return PacketBuildResult::Incomplete;
230                    }
231
232                    raw_buf.resize(self.raw_buf_offset + bytes_to_read, 0);
233
234                    let read_slice =
235                        &mut raw_buf[self.raw_buf_offset..self.raw_buf_offset + bytes_to_read];
236                    let bytes_read = data.read(read_slice).unwrap();
237
238                    self.raw_buf_offset += bytes_read;
239                    self.remaining_length -= bytes_read;
240
241                    if self.remaining_length == 0 {
242                        let raw_buf = self.raw_buf.take().unwrap();
243                        let fixed_header = self.header_buf[0];
244
245                        let packet_data = if self.is_publish_packet() {
246                            // Use Arc for PUBLISH packets
247                            PacketData::Publish(Arc::from(raw_buf.into_boxed_slice()))
248                        } else {
249                            // Use Vec for other packets
250                            PacketData::Normal(raw_buf)
251                        };
252
253                        let packet = RawPacket {
254                            fixed_header,
255                            data: packet_data,
256                        };
257                        self.reset();
258                        return PacketBuildResult::Complete(packet);
259                    }
260                    return PacketBuildResult::Incomplete;
261                }
262            }
263        }
264    }
265}