1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode::decode_packet, encode::EncodeLtd, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, MAX_PACKET_SIZE};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12pub struct Codec {
13 state: Cell<DecodeState>,
14 max_in_size: Cell<u32>,
15 max_out_size: Cell<u32>,
16 flags: Cell<CodecFlags>,
17}
18
19bitflags::bitflags! {
20 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
21 pub struct CodecFlags: u8 {
22 const NO_PROBLEM_INFO = 0b0000_0001;
23 const NO_RETAIN = 0b0000_0010;
24 const NO_SUB_IDS = 0b0000_1000;
25 }
26}
27
28#[derive(Debug, Clone, Copy)]
29enum DecodeState {
30 FrameHeader,
31 Frame(FixedHeader),
32}
33
34impl Codec {
35 pub fn new(max_in_size: u32, max_out_size: u32) -> Self {
37 Codec {
38 state: Cell::new(DecodeState::FrameHeader),
39 max_in_size: Cell::new(max_in_size),
40 max_out_size: Cell::new(max_out_size),
41 flags: Cell::new(CodecFlags::empty()),
42 }
43 }
44
45 pub fn max_inbound_size(&self) -> u32 {
50 self.max_in_size.get()
51 }
52
53 pub fn max_outbound_size(&self) -> u32 {
58 self.max_out_size.get()
59 }
60
61 pub fn set_max_inbound_size(&mut self, size: u32) {
66 self.max_in_size.set(size);
67 }
68
69 pub fn set_max_outbound_size(&mut self, mut size: u32) {
74 if size > 5 {
75 size -= 5;
77 }
78 self.max_out_size.set(size);
79 }
80
81 #[inline]
82 #[allow(dead_code)]
83 pub(crate) fn retain_available(&self) -> bool {
84 !self.flags.get().contains(CodecFlags::NO_RETAIN)
85 }
86
87 #[inline]
88 #[allow(dead_code)]
89 pub(crate) fn sub_ids_available(&self) -> bool {
90 !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
91 }
92
93 #[inline]
94 #[allow(dead_code)]
95 pub(crate) fn set_retain_available(&self, val: bool) {
96 let mut flags = self.flags.get();
97 flags.set(CodecFlags::NO_RETAIN, !val);
98 self.flags.set(flags);
99 }
100
101 #[inline]
102 #[allow(dead_code)]
103 pub(crate) fn set_sub_ids_available(&self, val: bool) {
104 let mut flags = self.flags.get();
105 flags.set(CodecFlags::NO_SUB_IDS, !val);
106 self.flags.set(flags);
107 }
108}
109
110impl Default for Codec {
111 fn default() -> Self {
112 Self::new(0, 0)
113 }
114}
115
116impl Decoder for Codec {
117 type Item = (Packet, u32);
118 type Error = DecodeError;
119
120 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
121 loop {
122 match self.state.get() {
123 DecodeState::FrameHeader => {
124 if src.len() < 2 {
125 return Ok(None);
126 }
127 let src_slice = src.as_ref();
128 let first_byte = src_slice[0];
129 match decode_variable_length(&src_slice[1..])? {
130 Some((remaining_length, consumed)) => {
131 let max_in_size = self.max_in_size.get();
133 if max_in_size != 0 && max_in_size < remaining_length {
134 log::debug!(
135 "MaxSizeExceeded max-size: {}, remaining: {}",
136 max_in_size,
137 remaining_length
138 );
139 return Err(DecodeError::MaxSizeExceeded);
140 }
141 src.advance(consumed + 1);
142 self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
143 let remaining_length = remaining_length as usize;
145 if src.len() < remaining_length {
146 src.reserve(remaining_length); return Ok(None);
149 }
150 }
151 None => {
152 return Ok(None);
153 }
154 }
155 }
156 DecodeState::Frame(fixed) => {
157 if src.len() < fixed.remaining_length as usize {
158 return Ok(None);
159 }
160 let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
161 let packet = decode_packet(packet_buf, fixed.first_byte)?;
162 self.state.set(DecodeState::FrameHeader);
163 src.reserve(5); if let Packet::Connect(ref pkt) = packet {
166 let mut flags = self.flags.get();
167 flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
168 self.flags.set(flags);
169 }
170 return Ok(Some((packet, fixed.remaining_length)));
171 }
172 }
173 }
174 }
175}
176
177impl Encoder<Packet> for Codec {
178 type Error = EncodeError;
180
181 fn encode(&mut self, mut item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
182 if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
184 match item {
185 Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt) => {
186 pkt.properties.clear();
187 let _ = pkt.reason_string.take();
188 }
189 Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt) => {
190 pkt.properties.clear();
191 let _ = pkt.reason_string.take();
192 }
193 Packet::Subscribe(ref mut pkt) => {
194 pkt.user_properties.clear();
195 }
196 Packet::SubscribeAck(ref mut pkt) => {
197 pkt.properties.clear();
198 let _ = pkt.reason_string.take();
199 }
200 Packet::Unsubscribe(ref mut pkt) => {
201 pkt.user_properties.clear();
202 }
203 Packet::UnsubscribeAck(ref mut pkt) => {
204 pkt.properties.clear();
205 let _ = pkt.reason_string.take();
206 }
207 Packet::Auth(ref mut pkt) => {
208 pkt.user_properties.clear();
209 let _ = pkt.reason_string.take();
210 }
211 _ => (),
212 }
213 }
214
215 let max_out_size = self.max_out_size.get();
216 let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
217 let content_size = item.encoded_size(max_size);
218 if content_size > max_size as usize {
219 return Err(EncodeError::OverMaxPacketSize);
220 }
221 dst.reserve(content_size + 5);
222 item.encode(dst, content_size as u32)?; Ok(())
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_size() {
        // Remaining length 9 exceeds the inbound limit of 5.
        let mut codec = Codec::new(5, 5);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert!(matches!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded)));
    }

    #[test]
    fn test_max_size_boundary() {
        // A remaining length equal to the limit is accepted. With no body buffered
        // yet, the header is consumed and the decoder asks for more data.
        let mut codec = Codec::new(5, 5);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x05");
        assert!(matches!(codec.decode(&mut buf), Ok(None)));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_incomplete_header() {
        // A single byte is not enough to read the fixed header; nothing is consumed.
        let mut codec = Codec::default();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\x30");
        assert!(matches!(codec.decode(&mut buf), Ok(None)));
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn test_set_max_outbound_size_reserves_fixed_header() {
        let mut codec = Codec::default();
        codec.set_max_outbound_size(100);
        assert_eq!(codec.max_outbound_size(), 95);
        // Values too small to hold the fixed header are kept unchanged.
        codec.set_max_outbound_size(5);
        assert_eq!(codec.max_outbound_size(), 5);
    }
}