Skip to main content

ntex_mqtt/v5/codec/
codec.rs

1use std::{cell::Cell, cmp::min, fmt, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BytePages, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{FixedHeader, MAX_PACKET_SIZE, packet_type};
8use crate::utils::decode_variable_length;
9
10use super::{Decoded, Encoded};
11use super::{Packet, decode::decode_packet, encode::EncodeLtd, packet::Publish};
12
13pub struct Codec {
14    state: Cell<DecodeState>,
15    max_in_size: Cell<u32>,
16    max_out_size: Cell<u32>,
17    min_chunk_size: Cell<u32>,
18    flags: Cell<CodecFlags>,
19    encoding_payload: Cell<Option<NonZeroU32>>,
20}
21
22bitflags::bitflags! {
23    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
24    pub struct CodecFlags: u8 {
25        const NO_PROBLEM_INFO = 0b0000_0001;
26        const NO_RETAIN       = 0b0000_0010;
27        const NO_SUB_IDS      = 0b0000_1000;
28    }
29}
30
31#[derive(Debug, Clone, Copy)]
32enum DecodeState {
33    FrameHeader,
34    Frame(FixedHeader),
35    PublishHeader(FixedHeader),
36    PublishProperties(u32, FixedHeader),
37    PublishPayload(u32),
38}
39
40impl Codec {
41    /// Create `Codec` instance
42    pub fn new() -> Self {
43        Codec {
44            state: Cell::new(DecodeState::FrameHeader),
45            max_in_size: Cell::new(0),
46            max_out_size: Cell::new(0),
47            min_chunk_size: Cell::new(0),
48            flags: Cell::new(CodecFlags::empty()),
49            encoding_payload: Cell::new(None),
50        }
51    }
52
53    /// Set min payload chunk size.
54    ///
55    /// If the minimum size is set to `0`, incoming payload chunks
56    /// will be processed immediately. Otherwise, the codec will
57    /// accumulate chunks until the total size reaches the specified minimum.
58    /// By default min size is set to `0`
59    pub fn set_min_chunk_size(&self, size: u32) {
60        self.min_chunk_size.set(size);
61    }
62
63    /// Set max inbound frame size.
64    ///
65    /// If max size is set to `0`, size is unlimited.
66    /// By default max size is set to `0`
67    pub fn max_inbound_size(&self) -> u32 {
68        self.max_in_size.get()
69    }
70
71    /// Set max outbound frame size.
72    ///
73    /// If max size is set to `0`, size is unlimited.
74    /// By default max size is set to `0`
75    pub fn max_outbound_size(&self) -> u32 {
76        self.max_out_size.get()
77    }
78
79    /// Set max inbound frame size.
80    ///
81    /// If max size is set to `0`, size is unlimited.
82    /// By default max size is set to `0`
83    pub fn set_max_inbound_size(&self, size: u32) {
84        self.max_in_size.set(size);
85    }
86
87    /// Set max outbound frame size.
88    ///
89    /// If max size is set to `0`, size is unlimited.
90    /// By default max size is set to `0`
91    pub fn set_max_outbound_size(&self, mut size: u32) {
92        if size > 5 {
93            // fixed header = 1, var_len(remaining.max_value()) = 4
94            size -= 5;
95        }
96        self.max_out_size.set(size);
97    }
98
99    pub(crate) fn retain_available(&self) -> bool {
100        !self.flags.get().contains(CodecFlags::NO_RETAIN)
101    }
102
103    pub(crate) fn sub_ids_available(&self) -> bool {
104        !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
105    }
106
107    pub(crate) fn set_retain_available(&self, val: bool) {
108        let mut flags = self.flags.get();
109        flags.set(CodecFlags::NO_RETAIN, !val);
110        self.flags.set(flags);
111    }
112
113    pub(crate) fn set_sub_ids_available(&self, val: bool) {
114        let mut flags = self.flags.get();
115        flags.set(CodecFlags::NO_SUB_IDS, !val);
116        self.flags.set(flags);
117    }
118}
119
120impl Default for Codec {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Decoder for Codec {
127    type Item = super::Decoded;
128    type Error = DecodeError;
129
130    #[allow(clippy::too_many_lines)]
131    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
132        loop {
133            match self.state.get() {
134                DecodeState::FrameHeader => {
135                    if src.len() < 2 {
136                        return Ok(None);
137                    }
138                    let src_slice = src.as_ref();
139                    let first_byte = src_slice[0];
140                    match decode_variable_length(&src_slice[1..])? {
141                        Some((remaining_length, consumed)) => {
142                            // check max message size
143                            let max_in_size = self.max_in_size.get();
144                            if max_in_size != 0 && max_in_size < remaining_length {
145                                log::debug!(
146                                    "MaxSizeExceeded max-size: {max_in_size}, remaining: {remaining_length}"
147                                );
148                                return Err(DecodeError::MaxSizeExceeded {
149                                    size: remaining_length,
150                                    max_size: max_in_size,
151                                });
152                            }
153                            src.advance(consumed + 1);
154
155                            if packet_type::is_publish(first_byte) {
156                                self.state.set(DecodeState::PublishHeader(FixedHeader {
157                                    first_byte,
158                                    remaining_length,
159                                }));
160                            } else {
161                                self.state.set(DecodeState::Frame(FixedHeader {
162                                    first_byte,
163                                    remaining_length,
164                                }));
165
166                                // todo: validate remaining_length against max frame size config
167                                let remaining_length = remaining_length as usize;
168                                if src.len() < remaining_length {
169                                    // todo: subtract?
170                                    src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
171                                    return Ok(None);
172                                }
173                            }
174                        }
175                        None => {
176                            return Ok(None);
177                        }
178                    }
179                }
180                DecodeState::PublishHeader(fixed) => {
181                    if let Some(len) = Publish::packet_header_size(src, fixed.first_byte)? {
182                        self.state.set(DecodeState::PublishProperties(len, fixed));
183                    } else {
184                        return Ok(None);
185                    }
186                }
187                DecodeState::PublishProperties(props_len, fixed) => {
188                    if src.len() < props_len as usize {
189                        return Ok(None);
190                    }
191                    let payload_len = fixed.remaining_length - props_len;
192                    let mut buf = src.split_to(props_len as usize);
193                    let publish = Publish::decode(&mut buf, fixed.first_byte, payload_len)?;
194
195                    let len = src.len() as u32;
196                    let min_chunk_size = self.min_chunk_size.get();
197                    return if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size
198                    {
199                        let payload = src.split_to(min(src.len(), payload_len as usize));
200                        let remaining = payload_len - payload.len() as u32;
201
202                        if remaining > 0 {
203                            self.state.set(DecodeState::PublishPayload(remaining));
204                        } else {
205                            self.state.set(DecodeState::FrameHeader);
206                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
207                        }
208
209                        Ok(Some(Decoded::Publish(publish, payload, fixed.remaining_length)))
210                    } else {
211                        self.state.set(DecodeState::PublishPayload(payload_len));
212                        Ok(Some(Decoded::Publish(
213                            publish,
214                            Bytes::new(),
215                            fixed.remaining_length,
216                        )))
217                    };
218                }
219                DecodeState::PublishPayload(remaining) => {
220                    let len = src.len() as u32;
221                    let min_chunk_size = self.min_chunk_size.get();
222
223                    return if (len >= remaining)
224                        || (min_chunk_size != 0 && len >= min_chunk_size)
225                    {
226                        let payload = src.split_to(min(src.len(), remaining as usize));
227                        let remaining = remaining - payload.len() as u32;
228
229                        let eof = if remaining > 0 {
230                            self.state.set(DecodeState::PublishPayload(remaining));
231                            false
232                        } else {
233                            self.state.set(DecodeState::FrameHeader);
234                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
235                            true
236                        };
237                        Ok(Some(Decoded::PayloadChunk(payload, eof)))
238                    } else {
239                        Ok(None)
240                    };
241                }
242                DecodeState::Frame(fixed) => {
243                    return if src.len() < fixed.remaining_length as usize {
244                        Ok(None)
245                    } else {
246                        let packet_buf = src.split_to(fixed.remaining_length as usize);
247                        let packet = decode_packet(packet_buf, fixed.first_byte)?;
248                        self.state.set(DecodeState::FrameHeader);
249                        src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
250
251                        if let Packet::Connect(ref pkt) = packet {
252                            let mut flags = self.flags.get();
253                            flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
254                            self.flags.set(flags);
255                        }
256                        Ok(Some(Decoded::Packet(packet, fixed.remaining_length)))
257                    };
258                }
259            }
260        }
261    }
262}
263
264impl Encoder for Codec {
265    type Item = Encoded;
266    type Error = EncodeError;
267
268    fn encodev(&self, mut item: Self::Item, dst: &mut BytePages) -> Result<(), EncodeError> {
269        // handle [MQTT 3.1.2.11.7]
270        if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
271            match item {
272                Encoded::Packet(
273                    Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt),
274                ) => {
275                    pkt.properties.clear();
276                    let _ = pkt.reason_string.take();
277                }
278                Encoded::Packet(
279                    Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt),
280                ) => {
281                    pkt.properties.clear();
282                    let _ = pkt.reason_string.take();
283                }
284                Encoded::Packet(Packet::Subscribe(ref mut pkt)) => {
285                    pkt.user_properties.clear();
286                }
287                Encoded::Packet(Packet::SubscribeAck(ref mut pkt)) => {
288                    pkt.properties.clear();
289                    let _ = pkt.reason_string.take();
290                }
291                Encoded::Packet(Packet::Unsubscribe(ref mut pkt)) => {
292                    pkt.user_properties.clear();
293                }
294                Encoded::Packet(Packet::UnsubscribeAck(ref mut pkt)) => {
295                    pkt.properties.clear();
296                    let _ = pkt.reason_string.take();
297                }
298                Encoded::Packet(Packet::Auth(ref mut pkt)) => {
299                    pkt.user_properties.clear();
300                    let _ = pkt.reason_string.take();
301                }
302                _ => (),
303            }
304        }
305
306        let max_out_size = self.max_out_size.get();
307        let max_size = if max_out_size != 0 {
308            max_out_size
309        } else {
310            MAX_PACKET_SIZE
311        };
312        match item {
313            Encoded::Packet(pkt) => {
314                if self.encoding_payload.get().is_some() {
315                    log::trace!("Expect payload, received {pkt:?}");
316                    Err(EncodeError::ExpectPayload)
317                } else {
318                    let content_size = pkt.encoded_size(max_size);
319                    if content_size > max_size as usize {
320                        Err(EncodeError::OverMaxPacketSize)
321                    } else {
322                        pkt.encode(dst, content_size as u32)?; // safe: max_size <= u32 max value
323                        Ok(())
324                    }
325                }
326            }
327            Encoded::Publish(pkt, buf) => {
328                let content_size = pkt.encoded_size(max_size) as u32;
329                if content_size > max_size {
330                    return Err(EncodeError::OverMaxPacketSize);
331                }
332
333                pkt.encode(dst, content_size)?; // safe: max_size <= u32 max value
334
335                let remaining = if let Some(buf) = buf {
336                    let remaining = pkt.payload_size - buf.len() as u32;
337                    dst.append(buf);
338                    remaining
339                } else {
340                    pkt.payload_size
341                };
342                self.encoding_payload.set(NonZeroU32::new(remaining));
343                Ok(())
344            }
345            Encoded::PayloadChunk(chunk) => {
346                if let Some(remaining) = self.encoding_payload.get() {
347                    let len = chunk.len() as u32;
348                    if len > remaining.get() {
349                        Err(EncodeError::OverPublishSize)
350                    } else {
351                        dst.append(chunk);
352                        self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
353                        Ok(())
354                    }
355                } else {
356                    Err(EncodeError::UnexpectedPayload)
357                }
358            }
359        }
360    }
361}
362
363impl Clone for Codec {
364    fn clone(&self) -> Self {
365        Codec {
366            state: Cell::new(DecodeState::FrameHeader),
367            max_in_size: self.max_in_size.clone(),
368            max_out_size: self.max_out_size.clone(),
369            min_chunk_size: self.min_chunk_size.clone(),
370            flags: Cell::new(CodecFlags::empty()),
371            encoding_payload: Cell::new(None),
372        }
373    }
374}
375
376impl fmt::Debug for Codec {
377    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378        f.debug_struct("Codec")
379            .field("state", &self.state)
380            .field("max_in_size", &self.max_in_size)
381            .field("max_out_size", &self.max_out_size)
382            .field("min_chunk_size", &self.min_chunk_size)
383            .field("flags", &self.flags)
384            .finish()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed header announcing 9 remaining bytes must be rejected when the
    /// inbound limit is 5, before any body bytes have arrived.
    #[test]
    fn test_max_size() {
        let decoder = Codec::new();
        decoder.set_max_inbound_size(5);

        // first byte 0x00, remaining length 9
        let mut input = BytesMut::new();
        input.extend_from_slice(&[0x00, 0x09]);

        let outcome = decoder.decode(&mut input).err();
        assert_eq!(outcome, Some(DecodeError::MaxSizeExceeded { size: 9, max_size: 5 }));
    }
}