1use std::{cell::Cell, cmp::min, fmt, num::NonZeroU32};
2
3use ntex_bytes::{Buf, BufMut, Bytes, BytesMut};
4use ntex_codec::{Decoder, Encoder};
5
6use crate::error::{DecodeError, EncodeError};
7use crate::types::{packet_type, FixedHeader, MAX_PACKET_SIZE};
8use crate::{payload::Payload, utils, utils::decode_variable_length};
9
10use super::{decode::decode_packet, encode::EncodeLtd, packet::Publish, Packet};
11use super::{Decoded, Encoded};
12
13pub struct Codec {
14 state: Cell<DecodeState>,
15 max_in_size: Cell<u32>,
16 max_out_size: Cell<u32>,
17 min_chunk_size: Cell<u32>,
18 flags: Cell<CodecFlags>,
19 encoding_payload: Cell<Option<NonZeroU32>>,
20}
21
22bitflags::bitflags! {
23 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
24 pub struct CodecFlags: u8 {
25 const NO_PROBLEM_INFO = 0b0000_0001;
26 const NO_RETAIN = 0b0000_0010;
27 const NO_SUB_IDS = 0b0000_1000;
28 }
29}
30
31#[derive(Debug, Clone, Copy)]
32enum DecodeState {
33 FrameHeader,
34 Frame(FixedHeader),
35 PublishHeader(FixedHeader),
36 PublishProperties(u32, FixedHeader),
37 PublishPayload(u32),
38}
39
40impl Codec {
41 pub fn new() -> Self {
43 Codec {
44 state: Cell::new(DecodeState::FrameHeader),
45 max_in_size: Cell::new(0),
46 max_out_size: Cell::new(0),
47 min_chunk_size: Cell::new(0),
48 flags: Cell::new(CodecFlags::empty()),
49 encoding_payload: Cell::new(None),
50 }
51 }
52
53 pub fn set_min_chunk_size(&self, size: u32) {
60 self.min_chunk_size.set(size)
61 }
62
63 pub fn max_inbound_size(&self) -> u32 {
68 self.max_in_size.get()
69 }
70
71 pub fn max_outbound_size(&self) -> u32 {
76 self.max_out_size.get()
77 }
78
79 pub fn set_max_inbound_size(&self, size: u32) {
84 self.max_in_size.set(size);
85 }
86
87 pub fn set_max_outbound_size(&self, mut size: u32) {
92 if size > 5 {
93 size -= 5;
95 }
96 self.max_out_size.set(size);
97 }
98
99 pub(crate) fn retain_available(&self) -> bool {
100 !self.flags.get().contains(CodecFlags::NO_RETAIN)
101 }
102
103 pub(crate) fn sub_ids_available(&self) -> bool {
104 !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
105 }
106
107 pub(crate) fn set_retain_available(&self, val: bool) {
108 let mut flags = self.flags.get();
109 flags.set(CodecFlags::NO_RETAIN, !val);
110 self.flags.set(flags);
111 }
112
113 pub(crate) fn set_sub_ids_available(&self, val: bool) {
114 let mut flags = self.flags.get();
115 flags.set(CodecFlags::NO_SUB_IDS, !val);
116 self.flags.set(flags);
117 }
118}
119
120impl Default for Codec {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Decoder for Codec {
127 type Item = super::Decoded;
128 type Error = DecodeError;
129
130 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
131 loop {
132 match self.state.get() {
133 DecodeState::FrameHeader => {
134 if src.len() < 2 {
135 return Ok(None);
136 }
137 let src_slice = src.as_ref();
138 let first_byte = src_slice[0];
139 match decode_variable_length(&src_slice[1..])? {
140 Some((remaining_length, consumed)) => {
141 let max_in_size = self.max_in_size.get();
143 if max_in_size != 0 && max_in_size < remaining_length {
144 log::debug!(
145 "MaxSizeExceeded max-size: {}, remaining: {}",
146 max_in_size,
147 remaining_length
148 );
149 return Err(DecodeError::MaxSizeExceeded);
150 }
151 src.advance(consumed + 1);
152
153 if packet_type::is_publish(first_byte) {
154 self.state.set(DecodeState::PublishHeader(FixedHeader {
155 first_byte,
156 remaining_length,
157 }));
158 } else {
159 self.state.set(DecodeState::Frame(FixedHeader {
160 first_byte,
161 remaining_length,
162 }));
163
164 let remaining_length = remaining_length as usize;
166 if src.len() < remaining_length {
167 src.reserve(remaining_length); return Ok(None);
170 }
171 }
172 }
173 None => {
174 return Ok(None);
175 }
176 }
177 }
178 DecodeState::PublishHeader(fixed) => {
179 if let Some(len) = Publish::packet_header_size(src, fixed.first_byte)? {
180 self.state.set(DecodeState::PublishProperties(len, fixed));
181 } else {
182 return Ok(None);
183 }
184 }
185 DecodeState::PublishProperties(props_len, fixed) => {
186 if src.len() < props_len as usize {
187 return Ok(None);
188 }
189 let payload_len = (fixed.remaining_length - props_len);
190 let mut buf = src.split_to(props_len as usize).freeze();
191 let publish = Publish::decode(&mut buf, fixed.first_byte, payload_len)?;
192
193 let len = src.len() as u32;
194 let min_chunk_size = self.min_chunk_size.get();
195 if len >= payload_len || min_chunk_size == 0 || len >= min_chunk_size {
196 let payload =
197 src.split_to(min(src.len(), payload_len as usize)).freeze();
198 let remaining = payload_len - payload.len() as u32;
199
200 if remaining > 0 {
201 self.state.set(DecodeState::PublishPayload(remaining));
202 } else {
203 self.state.set(DecodeState::FrameHeader);
204 src.reserve(5); }
206
207 return Ok(Some(Decoded::Publish(
208 publish,
209 payload,
210 fixed.remaining_length,
211 )));
212 } else {
213 self.state.set(DecodeState::PublishPayload(payload_len));
214 return Ok(Some(Decoded::Publish(
215 publish,
216 Bytes::new(),
217 fixed.remaining_length,
218 )));
219 }
220 }
221 DecodeState::PublishPayload(remaining) => {
222 let len = src.len() as u32;
223 let min_chunk_size = self.min_chunk_size.get();
224
225 return if (len >= remaining)
226 || (min_chunk_size != 0 && len >= min_chunk_size)
227 {
228 let payload = src.split_to(min(src.len(), remaining as usize)).freeze();
229 let remaining = remaining - payload.len() as u32;
230
231 let eof = if remaining > 0 {
232 self.state.set(DecodeState::PublishPayload(remaining));
233 false
234 } else {
235 self.state.set(DecodeState::FrameHeader);
236 src.reserve(5); true
238 };
239 Ok(Some(Decoded::PayloadChunk(payload, eof)))
240 } else {
241 Ok(None)
242 };
243 }
244 DecodeState::Frame(fixed) => {
245 return if src.len() < fixed.remaining_length as usize {
246 Ok(None)
247 } else {
248 let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
249 let packet = decode_packet(packet_buf, fixed.first_byte)?;
250 self.state.set(DecodeState::FrameHeader);
251 src.reserve(5); if let Packet::Connect(ref pkt) = packet {
254 let mut flags = self.flags.get();
255 flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
256 self.flags.set(flags);
257 }
258 Ok(Some(Decoded::Packet(packet, fixed.remaining_length)))
259 };
260 }
261 }
262 }
263 }
264}
265
266impl Encoder for Codec {
267 type Item = Encoded;
268 type Error = EncodeError;
269
270 fn encode(&self, mut item: Self::Item, dst: &mut BytesMut) -> Result<(), EncodeError> {
271 if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
273 match item {
274 Encoded::Packet(Packet::PublishAck(ref mut pkt))
275 | Encoded::Packet(Packet::PublishReceived(ref mut pkt)) => {
276 pkt.properties.clear();
277 let _ = pkt.reason_string.take();
278 }
279 Encoded::Packet(Packet::PublishRelease(ref mut pkt))
280 | Encoded::Packet(Packet::PublishComplete(ref mut pkt)) => {
281 pkt.properties.clear();
282 let _ = pkt.reason_string.take();
283 }
284 Encoded::Packet(Packet::Subscribe(ref mut pkt)) => {
285 pkt.user_properties.clear();
286 }
287 Encoded::Packet(Packet::SubscribeAck(ref mut pkt)) => {
288 pkt.properties.clear();
289 let _ = pkt.reason_string.take();
290 }
291 Encoded::Packet(Packet::Unsubscribe(ref mut pkt)) => {
292 pkt.user_properties.clear();
293 }
294 Encoded::Packet(Packet::UnsubscribeAck(ref mut pkt)) => {
295 pkt.properties.clear();
296 let _ = pkt.reason_string.take();
297 }
298 Encoded::Packet(Packet::Auth(ref mut pkt)) => {
299 pkt.user_properties.clear();
300 let _ = pkt.reason_string.take();
301 }
302 _ => (),
303 }
304 }
305
306 let max_out_size = self.max_out_size.get();
307 let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
308 match item {
309 Encoded::Packet(pkt) => {
310 if self.encoding_payload.get().is_some() {
311 log::trace!("Expect payload, received {:?}", pkt);
312 Err(EncodeError::ExpectPayload)
313 } else {
314 let content_size = pkt.encoded_size(max_size);
315 if content_size > max_size as usize {
316 Err(EncodeError::OverMaxPacketSize)
317 } else {
318 dst.reserve(content_size + 5);
319 pkt.encode(dst, content_size as u32)?; Ok(())
321 }
322 }
323 }
324 Encoded::Publish(pkt, buf) => {
325 let content_size = pkt.encoded_size(max_size) as u32;
326 if content_size > max_size {
327 return Err(EncodeError::OverMaxPacketSize);
328 }
329
330 let total_size = content_size - pkt.payload_size
331 + buf.as_ref().map(|b| b.len() as u32).unwrap_or(0);
332 dst.reserve((total_size + 5) as usize);
333 pkt.encode(dst, content_size)?; let remaining = if let Some(buf) = buf {
336 dst.extend_from_slice(&buf);
337 pkt.payload_size - buf.len() as u32
338 } else {
339 pkt.payload_size
340 };
341 self.encoding_payload.set(NonZeroU32::new(remaining as u32));
342 Ok(())
343 }
344 Encoded::PayloadChunk(chunk) => {
345 if let Some(remaining) = self.encoding_payload.get() {
346 let len = chunk.len() as u32;
347 if len > remaining.get() {
348 Err(EncodeError::OverPublishSize)
349 } else {
350 dst.extend_from_slice(&chunk);
351 self.encoding_payload.set(NonZeroU32::new(remaining.get() - len));
352 Ok(())
353 }
354 } else {
355 Err(EncodeError::UnexpectedPayload)
356 }
357 }
358 }
359 }
360}
361
362impl Clone for Codec {
363 fn clone(&self) -> Self {
364 Codec {
365 state: Cell::new(DecodeState::FrameHeader),
366 max_in_size: self.max_in_size.clone(),
367 max_out_size: self.max_out_size.clone(),
368 min_chunk_size: self.min_chunk_size.clone(),
369 flags: Cell::new(CodecFlags::empty()),
370 encoding_payload: Cell::new(None),
371 }
372 }
373}
374
375impl fmt::Debug for Codec {
376 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377 f.debug_struct("Codec")
378 .field("state", &self.state)
379 .field("max_in_size", &self.max_in_size)
380 .field("max_out_size", &self.max_out_size)
381 .field("min_chunk_size", &self.min_chunk_size)
382 .field("flags", &self.flags)
383 .finish()
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// A remaining length of 9 must be rejected when the inbound limit is 5.
    #[test]
    fn test_max_size() {
        // First byte 0, followed by a single-byte remaining length of 9.
        let mut input = BytesMut::new();
        input.extend_from_slice(b"\0\x09");

        let codec = Codec::new();
        codec.set_max_inbound_size(5);

        let failure = codec.decode(&mut input).err();
        assert_eq!(failure, Some(DecodeError::MaxSizeExceeded));
    }
}
399}