1use std::{cell::Cell, cmp::min, fmt, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BytePages, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{FixedHeader, MAX_PACKET_SIZE, packet_type};
8use crate::utils::decode_variable_length;
9
10use super::{Decoded, Encoded};
11use super::{Packet, decode::decode_packet, encode::EncodeLtd, packet::Publish};
12
13pub struct Codec {
14 state: Cell<DecodeState>,
15 max_in_size: Cell<u32>,
16 max_out_size: Cell<u32>,
17 min_chunk_size: Cell<u32>,
18 flags: Cell<CodecFlags>,
19 encoding_payload: Cell<Option<NonZeroU32>>,
20}
21
22bitflags::bitflags! {
23 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
24 pub struct CodecFlags: u8 {
25 const NO_PROBLEM_INFO = 0b0000_0001;
26 const NO_RETAIN = 0b0000_0010;
27 const NO_SUB_IDS = 0b0000_1000;
28 }
29}
30
/// Position of the incremental decoder, kept between calls to `decode` so that
/// a packet split across several reads can be resumed.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting for a complete fixed header (type byte + remaining length).
    FrameHeader,
    /// Fixed header of a non-PUBLISH packet decoded; waiting for the whole body.
    Frame(FixedHeader),
    /// Fixed header of a PUBLISH decoded; waiting until the length of its
    /// variable header can be determined.
    PublishHeader(FixedHeader),
    /// Length in bytes of the PUBLISH variable header (first field) is known;
    /// waiting for those bytes to be buffered.
    PublishProperties(u32, FixedHeader),
    /// PUBLISH already emitted; this many payload bytes remain to be delivered
    /// as payload chunks.
    PublishPayload(u32),
}
39
40impl Codec {
41 pub fn new() -> Self {
43 Codec {
44 state: Cell::new(DecodeState::FrameHeader),
45 max_in_size: Cell::new(0),
46 max_out_size: Cell::new(0),
47 min_chunk_size: Cell::new(0),
48 flags: Cell::new(CodecFlags::empty()),
49 encoding_payload: Cell::new(None),
50 }
51 }
52
53 pub fn set_min_chunk_size(&self, size: u32) {
60 self.min_chunk_size.set(size);
61 }
62
63 pub fn max_inbound_size(&self) -> u32 {
68 self.max_in_size.get()
69 }
70
71 pub fn max_outbound_size(&self) -> u32 {
76 self.max_out_size.get()
77 }
78
79 pub fn set_max_inbound_size(&self, size: u32) {
84 self.max_in_size.set(size);
85 }
86
87 pub fn set_max_outbound_size(&self, mut size: u32) {
92 if size > 5 {
93 size -= 5;
95 }
96 self.max_out_size.set(size);
97 }
98
99 pub(crate) fn retain_available(&self) -> bool {
100 !self.flags.get().contains(CodecFlags::NO_RETAIN)
101 }
102
103 pub(crate) fn sub_ids_available(&self) -> bool {
104 !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
105 }
106
107 pub(crate) fn set_retain_available(&self, val: bool) {
108 let mut flags = self.flags.get();
109 flags.set(CodecFlags::NO_RETAIN, !val);
110 self.flags.set(flags);
111 }
112
113 pub(crate) fn set_sub_ids_available(&self, val: bool) {
114 let mut flags = self.flags.get();
115 flags.set(CodecFlags::NO_SUB_IDS, !val);
116 self.flags.set(flags);
117 }
118}
119
120impl Default for Codec {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Decoder for Codec {
127 type Item = super::Decoded;
128 type Error = DecodeError;
129
130 #[allow(clippy::too_many_lines)]
131 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
132 loop {
133 match self.state.get() {
134 DecodeState::FrameHeader => {
135 if src.len() < 2 {
136 return Ok(None);
137 }
138 let src_slice = src.as_ref();
139 let first_byte = src_slice[0];
140 match decode_variable_length(&src_slice[1..])? {
141 Some((remaining_length, consumed)) => {
142 let max_in_size = self.max_in_size.get();
144 if max_in_size != 0 && max_in_size < remaining_length {
145 log::debug!(
146 "MaxSizeExceeded max-size: {max_in_size}, remaining: {remaining_length}"
147 );
148 return Err(DecodeError::MaxSizeExceeded {
149 size: remaining_length,
150 max_size: max_in_size,
151 });
152 }
153 src.advance(consumed + 1);
154
155 if packet_type::is_publish(first_byte) {
156 self.state.set(DecodeState::PublishHeader(FixedHeader {
157 first_byte,
158 remaining_length,
159 }));
160 } else {
161 self.state.set(DecodeState::Frame(FixedHeader {
162 first_byte,
163 remaining_length,
164 }));
165
166 let remaining_length = remaining_length as usize;
168 if src.len() < remaining_length {
169 src.reserve(remaining_length); return Ok(None);
172 }
173 }
174 }
175 None => {
176 return Ok(None);
177 }
178 }
179 }
180 DecodeState::PublishHeader(fixed) => {
181 if let Some(len) = Publish::packet_header_size(src, fixed.first_byte)? {
182 self.state.set(DecodeState::PublishProperties(len, fixed));
183 } else {
184 return Ok(None);
185 }
186 }
187 DecodeState::PublishProperties(props_len, fixed) => {
188 if src.len() < props_len as usize {
189 return Ok(None);
190 }
191 let payload_len = fixed.remaining_length - props_len;
192 let mut buf = src.split_to(props_len as usize);
193 let publish = Publish::decode(&mut buf, fixed.first_byte, payload_len)?;
194
195 let len = src.len() as u32;
196 let min_chunk_size = self.min_chunk_size.get();
197 return if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size
198 {
199 let payload = src.split_to(min(src.len(), payload_len as usize));
200 let remaining = payload_len - payload.len() as u32;
201
202 if remaining > 0 {
203 self.state.set(DecodeState::PublishPayload(remaining));
204 } else {
205 self.state.set(DecodeState::FrameHeader);
206 src.reserve(5); }
208
209 Ok(Some(Decoded::Publish(publish, payload, fixed.remaining_length)))
210 } else {
211 self.state.set(DecodeState::PublishPayload(payload_len));
212 Ok(Some(Decoded::Publish(
213 publish,
214 Bytes::new(),
215 fixed.remaining_length,
216 )))
217 };
218 }
219 DecodeState::PublishPayload(remaining) => {
220 let len = src.len() as u32;
221 let min_chunk_size = self.min_chunk_size.get();
222
223 return if (len >= remaining)
224 || (min_chunk_size != 0 && len >= min_chunk_size)
225 {
226 let payload = src.split_to(min(src.len(), remaining as usize));
227 let remaining = remaining - payload.len() as u32;
228
229 let eof = if remaining > 0 {
230 self.state.set(DecodeState::PublishPayload(remaining));
231 false
232 } else {
233 self.state.set(DecodeState::FrameHeader);
234 src.reserve(5); true
236 };
237 Ok(Some(Decoded::PayloadChunk(payload, eof)))
238 } else {
239 Ok(None)
240 };
241 }
242 DecodeState::Frame(fixed) => {
243 return if src.len() < fixed.remaining_length as usize {
244 Ok(None)
245 } else {
246 let packet_buf = src.split_to(fixed.remaining_length as usize);
247 let packet = decode_packet(packet_buf, fixed.first_byte)?;
248 self.state.set(DecodeState::FrameHeader);
249 src.reserve(5); if let Packet::Connect(ref pkt) = packet {
252 let mut flags = self.flags.get();
253 flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
254 self.flags.set(flags);
255 }
256 Ok(Some(Decoded::Packet(packet, fixed.remaining_length)))
257 };
258 }
259 }
260 }
261 }
262}
263
264impl Encoder for Codec {
265 type Item = Encoded;
266 type Error = EncodeError;
267
268 fn encodev(&self, mut item: Self::Item, dst: &mut BytePages) -> Result<(), EncodeError> {
269 if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
271 match item {
272 Encoded::Packet(
273 Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt),
274 ) => {
275 pkt.properties.clear();
276 let _ = pkt.reason_string.take();
277 }
278 Encoded::Packet(
279 Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt),
280 ) => {
281 pkt.properties.clear();
282 let _ = pkt.reason_string.take();
283 }
284 Encoded::Packet(Packet::Subscribe(ref mut pkt)) => {
285 pkt.user_properties.clear();
286 }
287 Encoded::Packet(Packet::SubscribeAck(ref mut pkt)) => {
288 pkt.properties.clear();
289 let _ = pkt.reason_string.take();
290 }
291 Encoded::Packet(Packet::Unsubscribe(ref mut pkt)) => {
292 pkt.user_properties.clear();
293 }
294 Encoded::Packet(Packet::UnsubscribeAck(ref mut pkt)) => {
295 pkt.properties.clear();
296 let _ = pkt.reason_string.take();
297 }
298 Encoded::Packet(Packet::Auth(ref mut pkt)) => {
299 pkt.user_properties.clear();
300 let _ = pkt.reason_string.take();
301 }
302 _ => (),
303 }
304 }
305
306 let max_out_size = self.max_out_size.get();
307 let max_size = if max_out_size != 0 {
308 max_out_size
309 } else {
310 MAX_PACKET_SIZE
311 };
312 match item {
313 Encoded::Packet(pkt) => {
314 if self.encoding_payload.get().is_some() {
315 log::trace!("Expect payload, received {pkt:?}");
316 Err(EncodeError::ExpectPayload)
317 } else {
318 let content_size = pkt.encoded_size(max_size);
319 if content_size > max_size as usize {
320 Err(EncodeError::OverMaxPacketSize)
321 } else {
322 pkt.encode(dst, content_size as u32)?; Ok(())
324 }
325 }
326 }
327 Encoded::Publish(pkt, buf) => {
328 let content_size = pkt.encoded_size(max_size) as u32;
329 if content_size > max_size {
330 return Err(EncodeError::OverMaxPacketSize);
331 }
332
333 pkt.encode(dst, content_size)?; let remaining = if let Some(buf) = buf {
336 let remaining = pkt.payload_size - buf.len() as u32;
337 dst.append(buf);
338 remaining
339 } else {
340 pkt.payload_size
341 };
342 self.encoding_payload.set(NonZeroU32::new(remaining));
343 Ok(())
344 }
345 Encoded::PayloadChunk(chunk) => {
346 if let Some(remaining) = self.encoding_payload.get() {
347 let len = chunk.len() as u32;
348 if len > remaining.get() {
349 Err(EncodeError::OverPublishSize)
350 } else {
351 dst.append(chunk);
352 self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
353 Ok(())
354 }
355 } else {
356 Err(EncodeError::UnexpectedPayload)
357 }
358 }
359 }
360 }
361}
362
363impl Clone for Codec {
364 fn clone(&self) -> Self {
365 Codec {
366 state: Cell::new(DecodeState::FrameHeader),
367 max_in_size: self.max_in_size.clone(),
368 max_out_size: self.max_out_size.clone(),
369 min_chunk_size: self.min_chunk_size.clone(),
370 flags: Cell::new(CodecFlags::empty()),
371 encoding_payload: Cell::new(None),
372 }
373 }
374}
375
376impl fmt::Debug for Codec {
377 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378 f.debug_struct("Codec")
379 .field("state", &self.state)
380 .field("max_in_size", &self.max_in_size)
381 .field("max_out_size", &self.max_out_size)
382 .field("min_chunk_size", &self.min_chunk_size)
383 .field("flags", &self.flags)
384 .finish()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// A remaining length above the configured inbound limit is rejected
    /// directly from the fixed header.
    #[test]
    fn test_max_size() {
        let codec = Codec::new();
        codec.set_max_inbound_size(5);

        // Packet type byte 0 followed by a remaining length of 9.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&[0x00_u8, 0x09]);

        let err = codec.decode(&mut buf).err();
        assert_eq!(err, Some(DecodeError::MaxSizeExceeded { size: 9, max_size: 5 }));
    }
}