ntex_mqtt/v3/codec/
codec.rs1use std::{cell::Cell, cmp::min, num::NonZeroU32};
2
3use ntex_bytes::{Buf, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{packet_type, FixedHeader, QoS};
8use crate::utils::decode_variable_length;
9
10use super::{decode, encode, Decoded, Encoded, Packet, Publish};
11
/// MQTT (v3) packet codec: decodes incoming bytes into [`Decoded`] items and
/// encodes [`Encoded`] items into bytes.
///
/// All mutable state lives in `Cell`s, so both decoding and encoding are driven
/// through `&self`.
#[derive(Debug, Clone)]
pub struct Codec {
    /// Current position of the incremental decoder.
    state: Cell<DecodeState>,
    /// Maximum packet size; `0` disables the check.
    max_size: Cell<u32>,
    /// Minimum number of buffered payload bytes before a partial publish payload
    /// is emitted; `0` has special meaning (see `set_min_chunk_size`).
    min_chunk_size: Cell<u32>,
    /// Publish payload bytes still expected via `Encoded::PayloadChunk`, or
    /// `None` when no payload is outstanding.
    encoding_payload: Cell<Option<NonZeroU32>>,
}
20
/// Incremental decoder state, advanced by `Decoder::decode`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum DecodeState {
    /// Waiting for a fixed header (type byte + variable-length remaining length).
    FrameHeader,
    /// Fixed header of a non-publish packet read; waiting for its full body.
    Frame(FixedHeader),
    /// Fixed header of a PUBLISH packet read; waiting for its variable header.
    PublishHeader(FixedHeader),
    /// Publish header already emitted; this many payload bytes remain to be emitted.
    PublishPayload(u32),
}
28
29impl Codec {
30 pub fn new() -> Self {
32 Codec {
33 state: Cell::new(DecodeState::FrameHeader),
34 max_size: Cell::new(0),
35 min_chunk_size: Cell::new(0),
36 encoding_payload: Cell::new(None),
37 }
38 }
39
40 pub fn set_max_size(&self, size: u32) {
45 self.max_size.set(size);
46 }
47
48 pub fn set_min_chunk_size(&self, size: u32) {
55 self.min_chunk_size.set(size)
56 }
57}
58
59impl Default for Codec {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl Decoder for Codec {
66 type Item = Decoded;
67 type Error = DecodeError;
68
69 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
70 loop {
71 match self.state.get() {
72 DecodeState::FrameHeader => {
73 if src.len() < 2 {
74 return Ok(None);
75 }
76 let src_slice = src.as_ref();
77 let first_byte = src_slice[0];
78 match decode_variable_length(&src_slice[1..])? {
79 Some((remaining_length, consumed)) => {
80 let max_size = self.max_size.get();
82 if max_size != 0 && max_size < remaining_length {
83 return Err(DecodeError::MaxSizeExceeded);
84 }
85 src.advance(consumed + 1);
86
87 if packet_type::is_publish(first_byte) {
88 self.state.set(DecodeState::PublishHeader(FixedHeader {
89 first_byte,
90 remaining_length,
91 }));
92 } else {
93 self.state.set(DecodeState::Frame(FixedHeader {
94 first_byte,
95 remaining_length,
96 }));
97 let remaining_length = remaining_length as usize;
99 if src.len() < remaining_length {
100 src.reserve(remaining_length); return Ok(None);
103 }
104 }
105 }
106 None => {
107 return Ok(None);
108 }
109 }
110 }
111 DecodeState::PublishHeader(fixed) => {
112 if let Some(hdr_len) = decode::publish_size(src, fixed.first_byte)? {
113 if src.len() < hdr_len as usize {
114 return Ok(None);
115 }
116 let payload_len = (fixed.remaining_length - hdr_len);
117 let mut buf = src.split_to(hdr_len as usize).freeze();
118 let publish = decode::decode_publish_packet(
119 &mut buf,
120 fixed.first_byte,
121 payload_len,
122 )?;
123
124 let len = src.len() as u32;
125 let min_chunk_size = self.min_chunk_size.get();
126 if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
127 let payload =
128 src.split_to(min(src.len(), payload_len as usize)).freeze();
129 let remaining = payload_len - payload.len() as u32;
130
131 if remaining > 0 {
132 self.state.set(DecodeState::PublishPayload(remaining));
133 } else {
134 self.state.set(DecodeState::FrameHeader);
135 src.reserve(5); }
137
138 return Ok(Some(Decoded::Publish(
139 publish,
140 payload,
141 fixed.remaining_length,
142 )));
143 } else {
144 self.state.set(DecodeState::PublishPayload(payload_len));
145 return Ok(Some(Decoded::Publish(
146 publish,
147 Bytes::new(),
148 fixed.remaining_length,
149 )));
150 }
151 }
152 return Ok(None);
153 }
154 DecodeState::PublishPayload(remaining) => {
155 let len = src.len() as u32;
156 let min_chunk_size = self.min_chunk_size.get();
157
158 return if (len >= remaining)
159 || (min_chunk_size != 0 && len >= min_chunk_size)
160 {
161 let payload = src.split_to(min(src.len(), remaining as usize)).freeze();
162 let remaining = remaining - payload.len() as u32;
163
164 let eof = if remaining > 0 {
165 self.state.set(DecodeState::PublishPayload(remaining));
166 false
167 } else {
168 self.state.set(DecodeState::FrameHeader);
169 src.reserve(5); true
171 };
172 Ok(Some(Decoded::PayloadChunk(payload, eof)))
173 } else {
174 Ok(None)
175 };
176 }
177 DecodeState::Frame(fixed) => {
178 if src.len() < fixed.remaining_length as usize {
179 return Ok(None);
180 }
181 let packet_buf = src.split_to(fixed.remaining_length as usize);
182 let packet = decode::decode_packet(packet_buf.freeze(), fixed.first_byte)?;
183 self.state.set(DecodeState::FrameHeader);
184 src.reserve(2);
185 return Ok(Some(Decoded::Packet(packet, fixed.remaining_length)));
186 }
187 }
188 }
189 }
190}
191
192impl Encoder for Codec {
193 type Item = Encoded;
194 type Error = EncodeError;
195
196 fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), EncodeError> {
197 match item {
198 Encoded::Packet(pkt) => {
199 let content_size = encode::get_encoded_size(&pkt);
200 dst.reserve(content_size + 5);
201 encode::encode(&pkt, dst, content_size as u32)?;
202 Ok(())
203 }
204 Encoded::Publish(pkt, buf) => {
205 if let Publish { qos, packet_id, .. } = pkt {
206 if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce)
207 && packet_id.is_none()
208 {
209 return Err(EncodeError::PacketIdRequired);
210 }
211 }
212
213 let content_size = encode::get_encoded_publish_size(&pkt) as u32;
214 if self.max_size.get() != 0 && content_size > self.max_size.get() {
215 return Err(EncodeError::OverMaxPacketSize);
216 }
217
218 let current_size = content_size - pkt.payload_size
219 + buf.as_ref().map(|b| b.len() as u32).unwrap_or(0);
220 dst.reserve((current_size + 5) as usize);
221 encode::encode_publish(&pkt, dst, content_size)?; let remaining = if let Some(buf) = buf {
224 dst.extend_from_slice(&buf);
225 pkt.payload_size - buf.len() as u32
226 } else {
227 pkt.payload_size
228 };
229 self.encoding_payload.set(NonZeroU32::new(remaining as u32));
230 Ok(())
231 }
232 Encoded::PayloadChunk(chunk) => {
233 if let Some(remaining) = self.encoding_payload.get() {
234 let len = chunk.len() as u32;
235 if len > remaining.get() {
236 Err(EncodeError::OverPublishSize)
237 } else {
238 dst.extend_from_slice(&chunk);
239 self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
240 Ok(())
241 }
242 } else {
243 Err(EncodeError::UnexpectedPayload)
244 }
245 }
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use ntex_bytes::{ByteString, Bytes};

    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_size(5);

        // Fixed header announcing a remaining length of 9, above the limit of 5.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded));
    }

    #[test]
    fn test_packet() {
        let codec = Codec::new();
        let mut buf = BytesMut::new();

        let pkt = Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload_size: 260 * 1024,
        };
        let payload = Bytes::from(Vec::from("a".repeat(260 * 1024)));
        codec
            .encode(Encoded::Publish(pkt.clone(), Some(payload.clone())), &mut buf)
            .unwrap();

        // The whole frame is buffered and `min_chunk_size` is 0, so the header and the
        // complete payload must come back as a single item.
        match codec.decode(&mut buf).unwrap().unwrap() {
            Decoded::Publish(pkt2, payload2, _) => {
                assert_eq!(pkt, pkt2);
                assert_eq!(payload, payload2);
            }
            _ => panic!("expected Decoded::Publish"),
        }
        assert!(buf.is_empty());
    }
}