ntex_mqtt/v3/codec/
codec.rs

1use std::{cell::Cell, cmp::min, num::NonZeroU32};
2
3use ntex_bytes::{Buf, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{packet_type, FixedHeader, QoS};
8use crate::utils::decode_variable_length;
9
10use super::{decode, encode, Decoded, Encoded, Packet, Publish};
11
12#[derive(Debug, Clone)]
13/// Mqtt v3.1.1 protocol codec
14pub struct Codec {
15    state: Cell<DecodeState>,
16    max_size: Cell<u32>,
17    min_chunk_size: Cell<u32>,
18    encoding_payload: Cell<Option<NonZeroU32>>,
19}
20
21#[derive(Debug, Copy, Clone, PartialEq, Eq)]
22enum DecodeState {
23    FrameHeader,
24    Frame(FixedHeader),
25    PublishHeader(FixedHeader),
26    PublishPayload(u32),
27}
28
29impl Codec {
30    /// Create `Codec` instance
31    pub fn new() -> Self {
32        Codec {
33            state: Cell::new(DecodeState::FrameHeader),
34            max_size: Cell::new(0),
35            min_chunk_size: Cell::new(0),
36            encoding_payload: Cell::new(None),
37        }
38    }
39
40    /// Set max inbound frame size.
41    ///
42    /// If max size is set to `0`, size is unlimited.
43    /// By default max size is set to `0`
44    pub fn set_max_size(&self, size: u32) {
45        self.max_size.set(size);
46    }
47
48    /// Set min payload chunk size.
49    ///
50    /// If the minimum size is set to `0`, incoming payload chunks
51    /// will be processed immediately. Otherwise, the codec will
52    /// accumulate chunks until the total size reaches the specified minimum.
53    /// By default min size is set to `0`
54    pub fn set_min_chunk_size(&self, size: u32) {
55        self.min_chunk_size.set(size)
56    }
57}
58
59impl Default for Codec {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl Decoder for Codec {
66    type Item = Decoded;
67    type Error = DecodeError;
68
69    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
70        loop {
71            match self.state.get() {
72                DecodeState::FrameHeader => {
73                    if src.len() < 2 {
74                        return Ok(None);
75                    }
76                    let src_slice = src.as_ref();
77                    let first_byte = src_slice[0];
78                    match decode_variable_length(&src_slice[1..])? {
79                        Some((remaining_length, consumed)) => {
80                            // check max message size
81                            let max_size = self.max_size.get();
82                            if max_size != 0 && max_size < remaining_length {
83                                return Err(DecodeError::MaxSizeExceeded);
84                            }
85                            src.advance(consumed + 1);
86
87                            if packet_type::is_publish(first_byte) {
88                                self.state.set(DecodeState::PublishHeader(FixedHeader {
89                                    first_byte,
90                                    remaining_length,
91                                }));
92                            } else {
93                                self.state.set(DecodeState::Frame(FixedHeader {
94                                    first_byte,
95                                    remaining_length,
96                                }));
97                                // todo: validate remaining_length against max frame size config
98                                let remaining_length = remaining_length as usize;
99                                if src.len() < remaining_length {
100                                    // todo: subtract?
101                                    src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
102                                    return Ok(None);
103                                }
104                            }
105                        }
106                        None => {
107                            return Ok(None);
108                        }
109                    }
110                }
111                DecodeState::PublishHeader(fixed) => {
112                    if let Some(hdr_len) = decode::publish_size(src, fixed.first_byte)? {
113                        if src.len() < hdr_len as usize {
114                            return Ok(None);
115                        }
116                        let payload_len = (fixed.remaining_length - hdr_len);
117                        let mut buf = src.split_to(hdr_len as usize).freeze();
118                        let publish = decode::decode_publish_packet(
119                            &mut buf,
120                            fixed.first_byte,
121                            payload_len,
122                        )?;
123
124                        let len = src.len() as u32;
125                        let min_chunk_size = self.min_chunk_size.get();
126                        if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
127                            let payload =
128                                src.split_to(min(src.len(), payload_len as usize)).freeze();
129                            let remaining = payload_len - payload.len() as u32;
130
131                            if remaining > 0 {
132                                self.state.set(DecodeState::PublishPayload(remaining));
133                            } else {
134                                self.state.set(DecodeState::FrameHeader);
135                                src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
136                            }
137
138                            return Ok(Some(Decoded::Publish(
139                                publish,
140                                payload,
141                                fixed.remaining_length,
142                            )));
143                        } else {
144                            self.state.set(DecodeState::PublishPayload(payload_len));
145                            return Ok(Some(Decoded::Publish(
146                                publish,
147                                Bytes::new(),
148                                fixed.remaining_length,
149                            )));
150                        }
151                    }
152                    return Ok(None);
153                }
154                DecodeState::PublishPayload(remaining) => {
155                    let len = src.len() as u32;
156                    let min_chunk_size = self.min_chunk_size.get();
157
158                    return if (len >= remaining)
159                        || (min_chunk_size != 0 && len >= min_chunk_size)
160                    {
161                        let payload = src.split_to(min(src.len(), remaining as usize)).freeze();
162                        let remaining = remaining - payload.len() as u32;
163
164                        let eof = if remaining > 0 {
165                            self.state.set(DecodeState::PublishPayload(remaining));
166                            false
167                        } else {
168                            self.state.set(DecodeState::FrameHeader);
169                            src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
170                            true
171                        };
172                        Ok(Some(Decoded::PayloadChunk(payload, eof)))
173                    } else {
174                        Ok(None)
175                    };
176                }
177                DecodeState::Frame(fixed) => {
178                    if src.len() < fixed.remaining_length as usize {
179                        return Ok(None);
180                    }
181                    let packet_buf = src.split_to(fixed.remaining_length as usize);
182                    let packet = decode::decode_packet(packet_buf.freeze(), fixed.first_byte)?;
183                    self.state.set(DecodeState::FrameHeader);
184                    src.reserve(2);
185                    return Ok(Some(Decoded::Packet(packet, fixed.remaining_length)));
186                }
187            }
188        }
189    }
190}
191
192impl Encoder for Codec {
193    type Item = Encoded;
194    type Error = EncodeError;
195
196    fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), EncodeError> {
197        match item {
198            Encoded::Packet(pkt) => {
199                let content_size = encode::get_encoded_size(&pkt);
200                dst.reserve(content_size + 5);
201                encode::encode(&pkt, dst, content_size as u32)?;
202                Ok(())
203            }
204            Encoded::Publish(pkt, buf) => {
205                if let Publish { qos, packet_id, .. } = pkt {
206                    if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce)
207                        && packet_id.is_none()
208                    {
209                        return Err(EncodeError::PacketIdRequired);
210                    }
211                }
212
213                let content_size = encode::get_encoded_publish_size(&pkt) as u32;
214                if self.max_size.get() != 0 && content_size > self.max_size.get() {
215                    return Err(EncodeError::OverMaxPacketSize);
216                }
217
218                let current_size = content_size - pkt.payload_size
219                    + buf.as_ref().map(|b| b.len() as u32).unwrap_or(0);
220                dst.reserve((current_size + 5) as usize);
221                encode::encode_publish(&pkt, dst, content_size)?; // safe: max_size <= u32 max value
222
223                let remaining = if let Some(buf) = buf {
224                    dst.extend_from_slice(&buf);
225                    pkt.payload_size - buf.len() as u32
226                } else {
227                    pkt.payload_size
228                };
229                self.encoding_payload.set(NonZeroU32::new(remaining as u32));
230                Ok(())
231            }
232            Encoded::PayloadChunk(chunk) => {
233                if let Some(remaining) = self.encoding_payload.get() {
234                    let len = chunk.len() as u32;
235                    if len > remaining.get() {
236                        Err(EncodeError::OverPublishSize)
237                    } else {
238                        dst.extend_from_slice(&chunk);
239                        self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
240                        Ok(())
241                    }
242                } else {
243                    Err(EncodeError::UnexpectedPayload)
244                }
245            }
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use ntex_bytes::{ByteString, Bytes};

    /// QoS 0 publish on `/test` declaring `payload_size` payload bytes.
    fn publish(payload_size: u32) -> Publish {
        Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload_size,
        }
    }

    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_size(5);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded));
    }

    #[test]
    fn test_packet() {
        let codec = Codec::new();
        let mut buf = BytesMut::new();

        let pkt = publish(260 * 1024);
        let payload = Bytes::from(Vec::from("a".repeat(260 * 1024)));
        codec.encode(Encoded::Publish(pkt.clone(), Some(payload.clone())), &mut buf).unwrap();

        match codec.decode(&mut buf).unwrap() {
            Some(Decoded::Publish(decoded, decoded_payload, _)) => {
                assert_eq!(decoded, pkt);
                assert_eq!(decoded_payload, payload);
            }
            _ => panic!("expected a publish packet"),
        }
        assert!(buf.is_empty());
    }

    #[test]
    fn test_chunked_payload() {
        let codec = Codec::new();
        codec.set_min_chunk_size(4);

        let pkt = publish(10);
        let mut encoded = BytesMut::new();
        codec
            .encode(Encoded::Publish(pkt.clone(), Some(Bytes::from_static(b"0123456789"))), &mut encoded)
            .unwrap();
        let encoded = encoded.freeze();
        // The 10 payload bytes come last.
        let header_end = encoded.len() - 10;

        // Headers plus 3 payload bytes (< min chunk size): publish without payload.
        let mut src = BytesMut::new();
        src.extend_from_slice(&encoded[..header_end + 3]);
        match codec.decode(&mut src).unwrap() {
            Some(Decoded::Publish(decoded, payload, _)) => {
                assert_eq!(decoded, pkt);
                assert!(payload.is_empty());
            }
            _ => panic!("expected a publish packet"),
        }
        // Still below the minimum chunk size.
        assert!(codec.decode(&mut src).unwrap().is_none());

        src.extend_from_slice(&encoded[header_end + 3..header_end + 6]);
        match codec.decode(&mut src).unwrap() {
            Some(Decoded::PayloadChunk(chunk, eof)) => {
                assert_eq!(&chunk[..], b"012345");
                assert!(!eof);
            }
            _ => panic!("expected a payload chunk"),
        }

        src.extend_from_slice(&encoded[header_end + 6..]);
        match codec.decode(&mut src).unwrap() {
            Some(Decoded::PayloadChunk(chunk, eof)) => {
                assert_eq!(&chunk[..], b"6789");
                assert!(eof);
            }
            _ => panic!("expected the final payload chunk"),
        }
        assert!(src.is_empty());
    }

    #[test]
    fn test_encode_errors() {
        let codec = Codec::new();
        let mut buf = BytesMut::new();

        let mut pkt = publish(0);
        pkt.qos = QoS::AtLeastOnce;
        assert!(matches!(
            codec.encode(Encoded::Publish(pkt, None), &mut buf),
            Err(EncodeError::PacketIdRequired)
        ));

        assert!(matches!(
            codec.encode(Encoded::PayloadChunk(Bytes::from_static(b"x")), &mut buf),
            Err(EncodeError::UnexpectedPayload)
        ));

        codec.set_max_size(5);
        assert!(matches!(
            codec.encode(Encoded::Publish(publish(10), None), &mut buf),
            Err(EncodeError::OverMaxPacketSize)
        ));
        assert!(buf.is_empty());
    }
}