// ntex_mqtt/v5/codec/codec.rs

1use std::{cell::Cell, cmp::min, fmt, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BufMut, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{packet_type, FixedHeader, MAX_PACKET_SIZE};
8use crate::{payload::Payload, utils, utils::decode_variable_length};
9
10use super::{decode::decode_packet, encode::EncodeLtd, packet::Publish, Packet};
11use super::{Decoded, Encoded};
12
13pub struct Codec {
14    state: Cell<DecodeState>,
15    max_in_size: Cell<u32>,
16    max_out_size: Cell<u32>,
17    min_chunk_size: Cell<u32>,
18    flags: Cell<CodecFlags>,
19    encoding_payload: Cell<Option<NonZeroU32>>,
20}
21
22bitflags::bitflags! {
23    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
24    pub struct CodecFlags: u8 {
25        const NO_PROBLEM_INFO = 0b0000_0001;
26        const NO_RETAIN       = 0b0000_0010;
27        const NO_SUB_IDS      = 0b0000_1000;
28    }
29}
30
31#[derive(Debug, Clone, Copy)]
32enum DecodeState {
33    FrameHeader,
34    Frame(FixedHeader),
35    PublishHeader(FixedHeader),
36    PublishProperties(u32, FixedHeader),
37    PublishPayload(u32),
38}
39
40impl Codec {
41    /// Create `Codec` instance
42    pub fn new() -> Self {
43        Codec {
44            state: Cell::new(DecodeState::FrameHeader),
45            max_in_size: Cell::new(0),
46            max_out_size: Cell::new(0),
47            min_chunk_size: Cell::new(0),
48            flags: Cell::new(CodecFlags::empty()),
49            encoding_payload: Cell::new(None),
50        }
51    }
52
53    /// Set min payload chunk size.
54    ///
55    /// If the minimum size is set to `0`, incoming payload chunks
56    /// will be processed immediately. Otherwise, the codec will
57    /// accumulate chunks until the total size reaches the specified minimum.
58    /// By default min size is set to `0`
59    pub fn set_min_chunk_size(&self, size: u32) {
60        self.min_chunk_size.set(size)
61    }
62
63    /// Set max inbound frame size.
64    ///
65    /// If max size is set to `0`, size is unlimited.
66    /// By default max size is set to `0`
67    pub fn max_inbound_size(&self) -> u32 {
68        self.max_in_size.get()
69    }
70
71    /// Set max outbound frame size.
72    ///
73    /// If max size is set to `0`, size is unlimited.
74    /// By default max size is set to `0`
75    pub fn max_outbound_size(&self) -> u32 {
76        self.max_out_size.get()
77    }
78
79    /// Set max inbound frame size.
80    ///
81    /// If max size is set to `0`, size is unlimited.
82    /// By default max size is set to `0`
83    pub fn set_max_inbound_size(&self, size: u32) {
84        self.max_in_size.set(size);
85    }
86
87    /// Set max outbound frame size.
88    ///
89    /// If max size is set to `0`, size is unlimited.
90    /// By default max size is set to `0`
91    pub fn set_max_outbound_size(&self, mut size: u32) {
92        if size > 5 {
93            // fixed header = 1, var_len(remaining.max_value()) = 4
94            size -= 5;
95        }
96        self.max_out_size.set(size);
97    }
98
99    pub(crate) fn retain_available(&self) -> bool {
100        !self.flags.get().contains(CodecFlags::NO_RETAIN)
101    }
102
103    pub(crate) fn sub_ids_available(&self) -> bool {
104        !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
105    }
106
107    pub(crate) fn set_retain_available(&self, val: bool) {
108        let mut flags = self.flags.get();
109        flags.set(CodecFlags::NO_RETAIN, !val);
110        self.flags.set(flags);
111    }
112
113    pub(crate) fn set_sub_ids_available(&self, val: bool) {
114        let mut flags = self.flags.get();
115        flags.set(CodecFlags::NO_SUB_IDS, !val);
116        self.flags.set(flags);
117    }
118}
119
120impl Default for Codec {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Decoder for Codec {
127    type Item = super::Decoded;
128    type Error = DecodeError;
129
130    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
131        loop {
132            match self.state.get() {
133                DecodeState::FrameHeader => {
134                    if src.len() < 2 {
135                        return Ok(None);
136                    }
137                    let src_slice = src.as_ref();
138                    let first_byte = src_slice[0];
139                    match decode_variable_length(&src_slice[1..])? {
140                        Some((remaining_length, consumed)) => {
141                            // check max message size
142                            let max_in_size = self.max_in_size.get();
143                            if max_in_size != 0 && max_in_size < remaining_length {
144                                log::debug!(
145                                    "MaxSizeExceeded max-size: {}, remaining: {}",
146                                    max_in_size,
147                                    remaining_length
148                                );
149                                return Err(DecodeError::MaxSizeExceeded);
150                            }
151                            src.advance(consumed + 1);
152
153                            if packet_type::is_publish(first_byte) {
154                                self.state.set(DecodeState::PublishHeader(FixedHeader {
155                                    first_byte,
156                                    remaining_length,
157                                }));
158                            } else {
159                                self.state.set(DecodeState::Frame(FixedHeader {
160                                    first_byte,
161                                    remaining_length,
162                                }));
163
164                                // todo: validate remaining_length against max frame size config
165                                let remaining_length = remaining_length as usize;
166                                if src.len() < remaining_length {
167                                    // todo: subtract?
168                                    src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
169                                    return Ok(None);
170                                }
171                            }
172                        }
173                        None => {
174                            return Ok(None);
175                        }
176                    }
177                }
178                DecodeState::PublishHeader(fixed) => {
179                    if let Some(len) = Publish::packet_header_size(src, fixed.first_byte)? {
180                        self.state.set(DecodeState::PublishProperties(len, fixed));
181                    } else {
182                        return Ok(None);
183                    }
184                }
185                DecodeState::PublishProperties(props_len, fixed) => {
186                    if src.len() < props_len as usize {
187                        return Ok(None);
188                    }
189                    let payload_len = (fixed.remaining_length - props_len);
190                    let mut buf = src.split_to(props_len as usize).freeze();
191                    let publish = Publish::decode(&mut buf, fixed.first_byte, payload_len)?;
192
193                    let len = src.len() as u32;
194                    let min_chunk_size = self.min_chunk_size.get();
195                    if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
196                        let payload =
197                            src.split_to(min(src.len(), payload_len as usize)).freeze();
198                        let remaining = payload_len - payload.len() as u32;
199
200                        if remaining > 0 {
201                            self.state.set(DecodeState::PublishPayload(remaining));
202                        } else {
203                            self.state.set(DecodeState::FrameHeader);
204                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
205                        }
206
207                        return Ok(Some(Decoded::Publish(
208                            publish,
209                            payload,
210                            fixed.remaining_length,
211                        )));
212                    } else {
213                        self.state.set(DecodeState::PublishPayload(payload_len));
214                        return Ok(Some(Decoded::Publish(
215                            publish,
216                            Bytes::new(),
217                            fixed.remaining_length,
218                        )));
219                    }
220                }
221                DecodeState::PublishPayload(remaining) => {
222                    let len = src.len() as u32;
223                    let min_chunk_size = self.min_chunk_size.get();
224
225                    return if (len >= remaining)
226                        || (min_chunk_size != 0 && len >= min_chunk_size)
227                    {
228                        let payload = src.split_to(min(src.len(), remaining as usize)).freeze();
229                        let remaining = remaining - payload.len() as u32;
230
231                        let eof = if remaining > 0 {
232                            self.state.set(DecodeState::PublishPayload(remaining));
233                            false
234                        } else {
235                            self.state.set(DecodeState::FrameHeader);
236                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
237                            true
238                        };
239                        Ok(Some(Decoded::PayloadChunk(payload, eof)))
240                    } else {
241                        Ok(None)
242                    };
243                }
244                DecodeState::Frame(fixed) => {
245                    return if src.len() < fixed.remaining_length as usize {
246                        Ok(None)
247                    } else {
248                        let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
249                        let packet = decode_packet(packet_buf, fixed.first_byte)?;
250                        self.state.set(DecodeState::FrameHeader);
251                        src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
252
253                        if let Packet::Connect(ref pkt) = packet {
254                            let mut flags = self.flags.get();
255                            flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
256                            self.flags.set(flags);
257                        }
258                        Ok(Some(Decoded::Packet(packet, fixed.remaining_length)))
259                    };
260                }
261            }
262        }
263    }
264}
265
266impl Encoder for Codec {
267    type Item = Encoded;
268    type Error = EncodeError;
269
270    fn encode(&self, mut item: Self::Item, dst: &mut BytesMut) -> Result<(), EncodeError> {
271        // handle [MQTT 3.1.2.11.7]
272        if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
273            match item {
274                Encoded::Packet(Packet::PublishAck(ref mut pkt))
275                | Encoded::Packet(Packet::PublishReceived(ref mut pkt)) => {
276                    pkt.properties.clear();
277                    let _ = pkt.reason_string.take();
278                }
279                Encoded::Packet(Packet::PublishRelease(ref mut pkt))
280                | Encoded::Packet(Packet::PublishComplete(ref mut pkt)) => {
281                    pkt.properties.clear();
282                    let _ = pkt.reason_string.take();
283                }
284                Encoded::Packet(Packet::Subscribe(ref mut pkt)) => {
285                    pkt.user_properties.clear();
286                }
287                Encoded::Packet(Packet::SubscribeAck(ref mut pkt)) => {
288                    pkt.properties.clear();
289                    let _ = pkt.reason_string.take();
290                }
291                Encoded::Packet(Packet::Unsubscribe(ref mut pkt)) => {
292                    pkt.user_properties.clear();
293                }
294                Encoded::Packet(Packet::UnsubscribeAck(ref mut pkt)) => {
295                    pkt.properties.clear();
296                    let _ = pkt.reason_string.take();
297                }
298                Encoded::Packet(Packet::Auth(ref mut pkt)) => {
299                    pkt.user_properties.clear();
300                    let _ = pkt.reason_string.take();
301                }
302                _ => (),
303            }
304        }
305
306        let max_out_size = self.max_out_size.get();
307        let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
308        match item {
309            Encoded::Packet(pkt) => {
310                if self.encoding_payload.get().is_some() {
311                    log::trace!("Expect payload, received {:?}", pkt);
312                    Err(EncodeError::ExpectPayload)
313                } else {
314                    let content_size = pkt.encoded_size(max_size);
315                    if content_size > max_size as usize {
316                        Err(EncodeError::OverMaxPacketSize)
317                    } else {
318                        dst.reserve(content_size + 5);
319                        pkt.encode(dst, content_size as u32)?; // safe: max_size <= u32 max value
320                        Ok(())
321                    }
322                }
323            }
324            Encoded::Publish(pkt, buf) => {
325                let content_size = pkt.encoded_size(max_size) as u32;
326                if content_size > max_size {
327                    return Err(EncodeError::OverMaxPacketSize);
328                }
329
330                let total_size = content_size - pkt.payload_size
331                    + buf.as_ref().map(|b| b.len() as u32).unwrap_or(0);
332                dst.reserve((total_size + 5) as usize);
333                pkt.encode(dst, content_size)?; // safe: max_size <= u32 max value
334
335                let remaining = if let Some(buf) = buf {
336                    dst.extend_from_slice(&buf);
337                    pkt.payload_size - buf.len() as u32
338                } else {
339                    pkt.payload_size
340                };
341                self.encoding_payload.set(NonZeroU32::new(remaining as u32));
342                Ok(())
343            }
344            Encoded::PayloadChunk(chunk) => {
345                if let Some(remaining) = self.encoding_payload.get() {
346                    let len = chunk.len() as u32;
347                    if len > remaining.get() {
348                        Err(EncodeError::OverPublishSize)
349                    } else {
350                        dst.extend_from_slice(&chunk);
351                        self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
352                        Ok(())
353                    }
354                } else {
355                    Err(EncodeError::UnexpectedPayload)
356                }
357            }
358        }
359    }
360}
361
362impl Clone for Codec {
363    fn clone(&self) -> Self {
364        Codec {
365            state: Cell::new(DecodeState::FrameHeader),
366            max_in_size: self.max_in_size.clone(),
367            max_out_size: self.max_out_size.clone(),
368            min_chunk_size: self.min_chunk_size.clone(),
369            flags: Cell::new(CodecFlags::empty()),
370            encoding_payload: Cell::new(None),
371        }
372    }
373}
374
375impl fmt::Debug for Codec {
376    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377        f.debug_struct("Codec")
378            .field("state", &self.state)
379            .field("max_in_size", &self.max_in_size)
380            .field("max_out_size", &self.max_out_size)
381            .field("min_chunk_size", &self.min_chunk_size)
382            .field("flags", &self.flags)
383            .finish()
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_inbound_size(5);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf).err(), Some(DecodeError::MaxSizeExceeded));
    }

    #[test]
    fn test_incomplete_header() {
        let codec = Codec::new();
        let mut buf = BytesMut::new();
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(b"\xC0");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // nothing is consumed until a complete fixed header is available
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn test_frame_within_max_size() {
        let codec = Codec::new();
        codec.set_max_inbound_size(5);
        let mut buf = BytesMut::new();
        // PINGREQ: packet type 12, remaining length 0
        buf.extend_from_slice(b"\xC0\x00");
        assert!(codec.decode(&mut buf).unwrap().is_some());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_max_outbound_size_reserves_fixed_header() {
        let codec = Codec::new();
        assert_eq!(codec.max_outbound_size(), 0);
        codec.set_max_outbound_size(100);
        assert_eq!(codec.max_outbound_size(), 95);
        codec.set_max_outbound_size(5);
        assert_eq!(codec.max_outbound_size(), 5);
    }
}