Skip to main content

ntex_mqtt/v3/codec/
codec.rs

1use std::{cell::Cell, cmp::min, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BytePages, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{FixedHeader, QoS, packet_type};
8use crate::utils::decode_variable_length;
9
10use super::{Decoded, Encoded, Publish, decode, encode};
11
12#[derive(Debug, Clone)]
13/// Mqtt v3.1.1 protocol codec
14pub struct Codec {
15    state: Cell<DecodeState>,
16    max_size: Cell<u32>,
17    min_chunk_size: Cell<u32>,
18    encoding_payload: Cell<Option<NonZeroU32>>,
19}
20
21#[derive(Debug, Copy, Clone, PartialEq, Eq)]
22enum DecodeState {
23    FrameHeader,
24    Frame(FixedHeader),
25    PublishHeader(FixedHeader),
26    PublishPayload(u32),
27}
28
29impl Codec {
30    /// Create `Codec` instance
31    pub fn new() -> Self {
32        Codec {
33            state: Cell::new(DecodeState::FrameHeader),
34            max_size: Cell::new(0),
35            min_chunk_size: Cell::new(0),
36            encoding_payload: Cell::new(None),
37        }
38    }
39
40    /// Set max inbound frame size.
41    ///
42    /// If max size is set to `0`, size is unlimited.
43    /// By default max size is set to `0`
44    pub fn set_max_size(&self, size: u32) {
45        self.max_size.set(size);
46    }
47
48    /// Set min payload chunk size.
49    ///
50    /// If the minimum size is set to `0`, incoming payload chunks
51    /// will be processed immediately. Otherwise, the codec will
52    /// accumulate chunks until the total size reaches the specified minimum.
53    /// By default min size is set to `0`
54    pub fn set_min_chunk_size(&self, size: u32) {
55        self.min_chunk_size.set(size);
56    }
57}
58
59impl Default for Codec {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl Decoder for Codec {
66    type Item = Decoded;
67    type Error = DecodeError;
68
69    #[allow(clippy::too_many_lines)]
70    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
71        loop {
72            match self.state.get() {
73                DecodeState::FrameHeader => {
74                    if src.len() < 2 {
75                        return Ok(None);
76                    }
77                    let src_slice = src.as_ref();
78                    let first_byte = src_slice[0];
79                    match decode_variable_length(&src_slice[1..])? {
80                        Some((remaining_length, consumed)) => {
81                            // check max message size
82                            let max_size = self.max_size.get();
83                            if max_size != 0 && max_size < remaining_length {
84                                return Err(DecodeError::MaxSizeExceeded {
85                                    size: remaining_length,
86                                    max_size,
87                                });
88                            }
89                            src.advance(consumed + 1);
90
91                            if packet_type::is_publish(first_byte) {
92                                self.state.set(DecodeState::PublishHeader(FixedHeader {
93                                    first_byte,
94                                    remaining_length,
95                                }));
96                            } else {
97                                self.state.set(DecodeState::Frame(FixedHeader {
98                                    first_byte,
99                                    remaining_length,
100                                }));
101                                // todo: validate remaining_length against max frame size config
102                                let remaining_length = remaining_length as usize;
103                                if src.len() < remaining_length {
104                                    // todo: subtract?
105                                    src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
106                                    return Ok(None);
107                                }
108                            }
109                        }
110                        None => {
111                            return Ok(None);
112                        }
113                    }
114                }
115                DecodeState::PublishHeader(fixed) => {
116                    if let Some(hdr_len) = decode::publish_size(src, fixed.first_byte)? {
117                        if src.len() < hdr_len as usize {
118                            return Ok(None);
119                        }
120                        let payload_len = fixed.remaining_length - hdr_len;
121                        let mut buf = src.split_to(hdr_len as usize);
122                        let publish = decode::decode_publish_packet(
123                            &mut buf,
124                            fixed.first_byte,
125                            payload_len,
126                        )?;
127
128                        let len = src.len() as u32;
129                        let min_chunk_size = self.min_chunk_size.get();
130                        if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
131                            let payload = src.split_to(min(src.len(), payload_len as usize));
132                            let remaining = payload_len - payload.len() as u32;
133
134                            if remaining > 0 {
135                                self.state.set(DecodeState::PublishPayload(remaining));
136                            } else {
137                                self.state.set(DecodeState::FrameHeader);
138                                src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
139                            }
140
141                            return Ok(Some(Decoded::Publish(
142                                publish,
143                                payload,
144                                fixed.remaining_length,
145                            )));
146                        }
147                        self.state.set(DecodeState::PublishPayload(payload_len));
148                        return Ok(Some(Decoded::Publish(
149                            publish,
150                            Bytes::new(),
151                            fixed.remaining_length,
152                        )));
153                    }
154                    return Ok(None);
155                }
156                DecodeState::PublishPayload(remaining) => {
157                    let len = src.len() as u32;
158                    let min_chunk_size = self.min_chunk_size.get();
159
160                    return if (len >= remaining)
161                        || (min_chunk_size != 0 && len >= min_chunk_size)
162                    {
163                        let payload = src.split_to(min(src.len(), remaining as usize));
164                        let remaining = remaining - payload.len() as u32;
165
166                        let eof = if remaining > 0 {
167                            self.state.set(DecodeState::PublishPayload(remaining));
168                            false
169                        } else {
170                            self.state.set(DecodeState::FrameHeader);
171                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
172                            true
173                        };
174                        Ok(Some(Decoded::PayloadChunk(payload, eof)))
175                    } else {
176                        Ok(None)
177                    };
178                }
179                DecodeState::Frame(fixed) => {
180                    if src.len() < fixed.remaining_length as usize {
181                        return Ok(None);
182                    }
183                    let packet_buf = src.split_to(fixed.remaining_length as usize);
184                    let packet = decode::decode_packet(packet_buf, fixed.first_byte)?;
185                    self.state.set(DecodeState::FrameHeader);
186                    src.reserve(2);
187                    return Ok(Some(Decoded::Packet(packet, fixed.remaining_length)));
188                }
189            }
190        }
191    }
192}
193
194impl Encoder for Codec {
195    type Item = Encoded;
196    type Error = EncodeError;
197
198    fn encodev(&self, item: Self::Item, dst: &mut BytePages) -> Result<(), EncodeError> {
199        match item {
200            Encoded::Packet(pkt) => {
201                let content_size = encode::get_encoded_size(&pkt);
202                encode::encode(&pkt, dst, content_size as u32)?;
203                Ok(())
204            }
205            Encoded::Publish(pkt, buf) => {
206                let Publish { qos, packet_id, .. } = pkt;
207                if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() {
208                    return Err(EncodeError::PacketIdRequired);
209                }
210
211                let content_size = encode::get_encoded_publish_size(&pkt) as u32;
212                if self.max_size.get() != 0 && content_size > self.max_size.get() {
213                    return Err(EncodeError::OverMaxPacketSize);
214                }
215
216                encode::encode_publish(&pkt, dst, content_size)?; // safe: max_size <= u32 max value
217
218                let remaining = if let Some(buf) = buf {
219                    let remaining = pkt.payload_size - buf.len() as u32;
220                    dst.append(buf);
221                    remaining
222                } else {
223                    pkt.payload_size
224                };
225                self.encoding_payload.set(NonZeroU32::new(remaining));
226                Ok(())
227            }
228            Encoded::PayloadChunk(chunk) => {
229                if let Some(remaining) = self.encoding_payload.get() {
230                    let len = chunk.len() as u32;
231                    if len > remaining.get() {
232                        Err(EncodeError::OverPublishSize)
233                    } else {
234                        dst.append(chunk);
235                        self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
236                        Ok(())
237                    }
238                } else {
239                    Err(EncodeError::UnexpectedPayload)
240                }
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use ntex_bytes::{ByteString, Bytes};

    /// A frame whose remaining length exceeds the limit is rejected.
    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_size(5);

        let mut input = BytesMut::new();
        input.extend_from_slice(b"\0\x09");

        let result = codec.decode(&mut input);
        assert_eq!(result, Err(DecodeError::MaxSizeExceeded { size: 9, max_size: 5 }));
    }

    /// A large publish packet survives an encode/decode round trip.
    #[test]
    fn test_packet() {
        const PAYLOAD_SIZE: usize = 260 * 1024;

        let codec = Codec::new();
        let mut pages = BytePages::default();

        let pkt = Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload_size: 260 * 1024,
        };
        let payload = Bytes::from(vec![b'a'; PAYLOAD_SIZE]);
        codec.encodev(Encoded::Publish(pkt.clone(), Some(payload)), &mut pages).unwrap();

        let mut encoded = BytesMut::from(pages.freeze());
        match codec.decode(&mut encoded).unwrap().unwrap() {
            Decoded::Publish(pkt2, _, _) => assert_eq!(pkt, pkt2),
            _ => panic!(),
        }
    }
}