1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode::decode_packet, encode::EncodeLtd, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, MAX_PACKET_SIZE};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12pub struct Codec {
13 state: Cell<DecodeState>,
14 max_in_size: Cell<u32>,
15 max_out_size: Cell<u32>,
16 flags: Cell<CodecFlags>,
17}
18
19bitflags::bitflags! {
20 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
21 pub struct CodecFlags: u8 {
22 const NO_PROBLEM_INFO = 0b0000_0001;
23 const NO_RETAIN = 0b0000_0010;
24 const NO_SUB_IDS = 0b0000_1000;
25 }
26}
27
28#[derive(Debug, Clone, Copy)]
29enum DecodeState {
30 FrameHeader,
31 Frame(FixedHeader),
32}
33
34impl Codec {
35 pub fn new(max_in_size: u32, max_out_size: u32) -> Self {
37 Codec {
38 state: Cell::new(DecodeState::FrameHeader),
39 max_in_size: Cell::new(max_in_size),
40 max_out_size: Cell::new(max_out_size),
41 flags: Cell::new(CodecFlags::empty()),
42 }
43 }
44
45 pub fn max_inbound_size(&self) -> u32 {
50 self.max_in_size.get()
51 }
52
53 pub fn max_outbound_size(&self) -> u32 {
58 self.max_out_size.get()
59 }
60
61 pub fn set_max_inbound_size(&mut self, size: u32) {
66 self.max_in_size.set(size);
67 }
68
69 pub fn set_max_outbound_size(&mut self, mut size: u32) {
74 if size > 5 {
75 size -= 5;
77 }
78 self.max_out_size.set(size);
79 }
80
81 #[inline]
82 #[allow(dead_code)]
83 pub(crate) fn retain_available(&self) -> bool {
84 !self.flags.get().contains(CodecFlags::NO_RETAIN)
85 }
86
87 #[inline]
88 #[allow(dead_code)]
89 pub(crate) fn sub_ids_available(&self) -> bool {
90 !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
91 }
92
93 #[inline]
94 #[allow(dead_code)]
95 pub(crate) fn set_retain_available(&self, val: bool) {
96 let mut flags = self.flags.get();
97 flags.set(CodecFlags::NO_RETAIN, !val);
98 self.flags.set(flags);
99 }
100
101 #[inline]
102 #[allow(dead_code)]
103 pub(crate) fn set_sub_ids_available(&self, val: bool) {
104 let mut flags = self.flags.get();
105 flags.set(CodecFlags::NO_SUB_IDS, !val);
106 self.flags.set(flags);
107 }
108}
109
110impl Default for Codec {
111 fn default() -> Self {
112 Self::new(0, 0)
113 }
114}
115
116impl Decoder for Codec {
117 type Item = (Packet, u32);
118 type Error = DecodeError;
119
120 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
121 loop {
122 match self.state.get() {
123 DecodeState::FrameHeader => {
124 if src.len() < 2 {
125 return Ok(None);
126 }
127 let src_slice = src.as_ref();
128 let first_byte = src_slice[0];
129 match decode_variable_length(&src_slice[1..])? {
130 Some((remaining_length, consumed)) => {
131 let max_in_size = self.max_in_size.get();
133 if max_in_size != 0 && max_in_size < remaining_length {
134 log::debug!(
135 "MaxSizeExceeded max-size: {max_in_size}, remaining: {remaining_length}"
136 );
137 return Err(DecodeError::MaxSizeExceeded);
138 }
139 src.advance(consumed + 1);
140 self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
141 let remaining_length = remaining_length as usize;
143 if src.len() < remaining_length {
144 src.reserve(remaining_length); return Ok(None);
147 }
148 }
149 None => {
150 return Ok(None);
151 }
152 }
153 }
154 DecodeState::Frame(fixed) => {
155 if src.len() < fixed.remaining_length as usize {
156 return Ok(None);
157 }
158 let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
159 let packet = decode_packet(packet_buf, fixed.first_byte)?;
160 self.state.set(DecodeState::FrameHeader);
161 src.reserve(5); if let Packet::Connect(ref pkt) = packet {
164 let mut flags = self.flags.get();
165 flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
166 self.flags.set(flags);
167 }
168 return Ok(Some((packet, fixed.remaining_length)));
169 }
170 }
171 }
172 }
173}
174
175impl Encoder<Packet> for Codec {
176 type Error = EncodeError;
178
179 fn encode(&mut self, mut item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
180 if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
182 match item {
183 Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt) => {
184 pkt.properties.clear();
185 let _ = pkt.reason_string.take();
186 }
187 Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt) => {
188 pkt.properties.clear();
189 let _ = pkt.reason_string.take();
190 }
191 Packet::Subscribe(ref mut pkt) => {
192 pkt.user_properties.clear();
193 }
194 Packet::SubscribeAck(ref mut pkt) => {
195 pkt.properties.clear();
196 let _ = pkt.reason_string.take();
197 }
198 Packet::Unsubscribe(ref mut pkt) => {
199 pkt.user_properties.clear();
200 }
201 Packet::UnsubscribeAck(ref mut pkt) => {
202 pkt.properties.clear();
203 let _ = pkt.reason_string.take();
204 }
205 Packet::Auth(ref mut pkt) => {
206 pkt.user_properties.clear();
207 let _ = pkt.reason_string.take();
208 }
209 _ => (),
210 }
211 }
212
213 let max_out_size = self.max_out_size.get();
214 let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
215 let content_size = item.encoded_size(max_size);
216 if content_size > max_size as usize {
217 return Err(EncodeError::OverMaxPacketSize);
218 }
219 dst.reserve(content_size + 5);
220 item.encode(dst, content_size as u32)?; Ok(())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// A header announcing a 9-byte body must be rejected when the inbound limit is 5.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::new(5, 5);
        let mut buf = BytesMut::from(&b"\x00\x09"[..]);
        let outcome = codec.decode(&mut buf);
        assert!(matches!(outcome, Err(DecodeError::MaxSizeExceeded)));
    }
}