ntex_mqtt/v3/codec/
codec.rs1use std::{cell::Cell, cmp::min, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BytePages, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{FixedHeader, QoS, packet_type};
8use crate::utils::decode_variable_length;
9
10use super::{Decoded, Encoded, Publish, decode, encode};
11
12#[derive(Debug, Clone)]
13pub struct Codec {
15 state: Cell<DecodeState>,
16 max_size: Cell<u32>,
17 min_chunk_size: Cell<u32>,
18 encoding_payload: Cell<Option<NonZeroU32>>,
19}
20
21#[derive(Debug, Copy, Clone, PartialEq, Eq)]
22enum DecodeState {
23 FrameHeader,
24 Frame(FixedHeader),
25 PublishHeader(FixedHeader),
26 PublishPayload(u32),
27}
28
29impl Codec {
30 pub fn new() -> Self {
32 Codec {
33 state: Cell::new(DecodeState::FrameHeader),
34 max_size: Cell::new(0),
35 min_chunk_size: Cell::new(0),
36 encoding_payload: Cell::new(None),
37 }
38 }
39
40 pub fn set_max_size(&self, size: u32) {
45 self.max_size.set(size);
46 }
47
48 pub fn set_min_chunk_size(&self, size: u32) {
55 self.min_chunk_size.set(size);
56 }
57}
58
59impl Default for Codec {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl Decoder for Codec {
66 type Item = Decoded;
67 type Error = DecodeError;
68
69 #[allow(clippy::too_many_lines)]
70 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
71 loop {
72 match self.state.get() {
73 DecodeState::FrameHeader => {
74 if src.len() < 2 {
75 return Ok(None);
76 }
77 let src_slice = src.as_ref();
78 let first_byte = src_slice[0];
79 match decode_variable_length(&src_slice[1..])? {
80 Some((remaining_length, consumed)) => {
81 let max_size = self.max_size.get();
83 if max_size != 0 && max_size < remaining_length {
84 return Err(DecodeError::MaxSizeExceeded {
85 size: remaining_length,
86 max_size,
87 });
88 }
89 src.advance(consumed + 1);
90
91 if packet_type::is_publish(first_byte) {
92 self.state.set(DecodeState::PublishHeader(FixedHeader {
93 first_byte,
94 remaining_length,
95 }));
96 } else {
97 self.state.set(DecodeState::Frame(FixedHeader {
98 first_byte,
99 remaining_length,
100 }));
101 let remaining_length = remaining_length as usize;
103 if src.len() < remaining_length {
104 src.reserve(remaining_length); return Ok(None);
107 }
108 }
109 }
110 None => {
111 return Ok(None);
112 }
113 }
114 }
115 DecodeState::PublishHeader(fixed) => {
116 if let Some(hdr_len) = decode::publish_size(src, fixed.first_byte)? {
117 if src.len() < hdr_len as usize {
118 return Ok(None);
119 }
120 let payload_len = fixed.remaining_length - hdr_len;
121 let mut buf = src.split_to(hdr_len as usize);
122 let publish = decode::decode_publish_packet(
123 &mut buf,
124 fixed.first_byte,
125 payload_len,
126 )?;
127
128 let len = src.len() as u32;
129 let min_chunk_size = self.min_chunk_size.get();
130 if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
131 let payload = src.split_to(min(src.len(), payload_len as usize));
132 let remaining = payload_len - payload.len() as u32;
133
134 if remaining > 0 {
135 self.state.set(DecodeState::PublishPayload(remaining));
136 } else {
137 self.state.set(DecodeState::FrameHeader);
138 src.reserve(5); }
140
141 return Ok(Some(Decoded::Publish(
142 publish,
143 payload,
144 fixed.remaining_length,
145 )));
146 }
147 self.state.set(DecodeState::PublishPayload(payload_len));
148 return Ok(Some(Decoded::Publish(
149 publish,
150 Bytes::new(),
151 fixed.remaining_length,
152 )));
153 }
154 return Ok(None);
155 }
156 DecodeState::PublishPayload(remaining) => {
157 let len = src.len() as u32;
158 let min_chunk_size = self.min_chunk_size.get();
159
160 return if (len >= remaining)
161 || (min_chunk_size != 0 && len >= min_chunk_size)
162 {
163 let payload = src.split_to(min(src.len(), remaining as usize));
164 let remaining = remaining - payload.len() as u32;
165
166 let eof = if remaining > 0 {
167 self.state.set(DecodeState::PublishPayload(remaining));
168 false
169 } else {
170 self.state.set(DecodeState::FrameHeader);
171 src.reserve(5); true
173 };
174 Ok(Some(Decoded::PayloadChunk(payload, eof)))
175 } else {
176 Ok(None)
177 };
178 }
179 DecodeState::Frame(fixed) => {
180 if src.len() < fixed.remaining_length as usize {
181 return Ok(None);
182 }
183 let packet_buf = src.split_to(fixed.remaining_length as usize);
184 let packet = decode::decode_packet(packet_buf, fixed.first_byte)?;
185 self.state.set(DecodeState::FrameHeader);
186 src.reserve(2);
187 return Ok(Some(Decoded::Packet(packet, fixed.remaining_length)));
188 }
189 }
190 }
191 }
192}
193
194impl Encoder for Codec {
195 type Item = Encoded;
196 type Error = EncodeError;
197
198 fn encodev(&self, item: Self::Item, dst: &mut BytePages) -> Result<(), EncodeError> {
199 match item {
200 Encoded::Packet(pkt) => {
201 let content_size = encode::get_encoded_size(&pkt);
202 encode::encode(&pkt, dst, content_size as u32)?;
203 Ok(())
204 }
205 Encoded::Publish(pkt, buf) => {
206 let Publish { qos, packet_id, .. } = pkt;
207 if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() {
208 return Err(EncodeError::PacketIdRequired);
209 }
210
211 let content_size = encode::get_encoded_publish_size(&pkt) as u32;
212 if self.max_size.get() != 0 && content_size > self.max_size.get() {
213 return Err(EncodeError::OverMaxPacketSize);
214 }
215
216 encode::encode_publish(&pkt, dst, content_size)?; let remaining = if let Some(buf) = buf {
219 let remaining = pkt.payload_size - buf.len() as u32;
220 dst.append(buf);
221 remaining
222 } else {
223 pkt.payload_size
224 };
225 self.encoding_payload.set(NonZeroU32::new(remaining));
226 Ok(())
227 }
228 Encoded::PayloadChunk(chunk) => {
229 if let Some(remaining) = self.encoding_payload.get() {
230 let len = chunk.len() as u32;
231 if len > remaining.get() {
232 Err(EncodeError::OverPublishSize)
233 } else {
234 dst.append(chunk);
235 self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
236 Ok(())
237 }
238 } else {
239 Err(EncodeError::UnexpectedPayload)
240 }
241 }
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use ntex_bytes::{ByteString, Bytes};

    /// A fixed header that announces more bytes than `max_size` must be rejected.
    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_size(5);

        // Packet type byte 0 followed by remaining length 9.
        let mut input = BytesMut::new();
        input.extend_from_slice(b"\x00\x09");

        let outcome = codec.decode(&mut input);
        assert_eq!(outcome, Err(DecodeError::MaxSizeExceeded { size: 9, max_size: 5 }));
    }

    /// A publish with a large payload survives an encode/decode round trip.
    #[test]
    fn test_packet() {
        const PAYLOAD_SIZE: usize = 260 * 1024;

        let codec = Codec::new();
        let original = Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload_size: PAYLOAD_SIZE as u32,
        };
        let body = Bytes::from(vec![b'a'; PAYLOAD_SIZE]);

        let mut pages = BytePages::default();
        codec.encodev(Encoded::Publish(original.clone(), Some(body)), &mut pages).unwrap();

        let mut input = BytesMut::from(pages.freeze());
        match codec.decode(&mut input).unwrap().unwrap() {
            Decoded::Publish(decoded, _, _) => assert_eq!(original, decoded),
            other => panic!("expected a publish packet, got {other:?}"),
        }
    }
}
287}