rmqtt_codec/v5/
codec.rs

1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode::decode_packet, encode::EncodeLtd, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, MAX_PACKET_SIZE};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12pub struct Codec {
13    state: Cell<DecodeState>,
14    max_in_size: Cell<u32>,
15    max_out_size: Cell<u32>,
16    flags: Cell<CodecFlags>,
17}
18
19bitflags::bitflags! {
20    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
21    pub struct CodecFlags: u8 {
22        const NO_PROBLEM_INFO = 0b0000_0001;
23        const NO_RETAIN       = 0b0000_0010;
24        const NO_SUB_IDS      = 0b0000_1000;
25    }
26}
27
28#[derive(Debug, Clone, Copy)]
29enum DecodeState {
30    FrameHeader,
31    Frame(FixedHeader),
32}
33
34impl Codec {
35    /// Create `Codec` instance
36    pub fn new(max_in_size: u32, max_out_size: u32) -> Self {
37        Codec {
38            state: Cell::new(DecodeState::FrameHeader),
39            max_in_size: Cell::new(max_in_size),
40            max_out_size: Cell::new(max_out_size),
41            flags: Cell::new(CodecFlags::empty()),
42        }
43    }
44
45    /// Set max inbound frame size.
46    ///
47    /// If max size is set to `0`, size is unlimited.
48    /// By default max size is set to `0`
49    pub fn max_inbound_size(&self) -> u32 {
50        self.max_in_size.get()
51    }
52
53    /// Set max outbound frame size.
54    ///
55    /// If max size is set to `0`, size is unlimited.
56    /// By default max size is set to `0`
57    pub fn max_outbound_size(&self) -> u32 {
58        self.max_out_size.get()
59    }
60
61    /// Set max inbound frame size.
62    ///
63    /// If max size is set to `0`, size is unlimited.
64    /// By default max size is set to `0`
65    pub fn set_max_inbound_size(&mut self, size: u32) {
66        self.max_in_size.set(size);
67    }
68
69    /// Set max outbound frame size.
70    ///
71    /// If max size is set to `0`, size is unlimited.
72    /// By default max size is set to `0`
73    pub fn set_max_outbound_size(&mut self, mut size: u32) {
74        if size > 5 {
75            // fixed header = 1, var_len(remaining.max_value()) = 4
76            size -= 5;
77        }
78        self.max_out_size.set(size);
79    }
80
81    #[inline]
82    #[allow(dead_code)]
83    pub(crate) fn retain_available(&self) -> bool {
84        !self.flags.get().contains(CodecFlags::NO_RETAIN)
85    }
86
87    #[inline]
88    #[allow(dead_code)]
89    pub(crate) fn sub_ids_available(&self) -> bool {
90        !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
91    }
92
93    #[inline]
94    #[allow(dead_code)]
95    pub(crate) fn set_retain_available(&self, val: bool) {
96        let mut flags = self.flags.get();
97        flags.set(CodecFlags::NO_RETAIN, !val);
98        self.flags.set(flags);
99    }
100
101    #[inline]
102    #[allow(dead_code)]
103    pub(crate) fn set_sub_ids_available(&self, val: bool) {
104        let mut flags = self.flags.get();
105        flags.set(CodecFlags::NO_SUB_IDS, !val);
106        self.flags.set(flags);
107    }
108}
109
110impl Default for Codec {
111    fn default() -> Self {
112        Self::new(0, 0)
113    }
114}
115
116impl Decoder for Codec {
117    type Item = (Packet, u32);
118    type Error = DecodeError;
119
120    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
121        loop {
122            match self.state.get() {
123                DecodeState::FrameHeader => {
124                    if src.len() < 2 {
125                        return Ok(None);
126                    }
127                    let src_slice = src.as_ref();
128                    let first_byte = src_slice[0];
129                    match decode_variable_length(&src_slice[1..])? {
130                        Some((remaining_length, consumed)) => {
131                            // check max message size
132                            let max_in_size = self.max_in_size.get();
133                            if max_in_size != 0 && max_in_size < remaining_length {
134                                log::debug!(
135                                    "MaxSizeExceeded max-size: {max_in_size}, remaining: {remaining_length}"
136                                );
137                                return Err(DecodeError::MaxSizeExceeded);
138                            }
139                            src.advance(consumed + 1);
140                            self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
141                            // todo: validate remaining_length against max frame size config
142                            let remaining_length = remaining_length as usize;
143                            if src.len() < remaining_length {
144                                // todo: subtract?
145                                src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
146                                return Ok(None);
147                            }
148                        }
149                        None => {
150                            return Ok(None);
151                        }
152                    }
153                }
154                DecodeState::Frame(fixed) => {
155                    if src.len() < fixed.remaining_length as usize {
156                        return Ok(None);
157                    }
158                    let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
159                    let packet = decode_packet(packet_buf, fixed.first_byte)?;
160                    self.state.set(DecodeState::FrameHeader);
161                    src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
162
163                    if let Packet::Connect(ref pkt) = packet {
164                        let mut flags = self.flags.get();
165                        flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
166                        self.flags.set(flags);
167                    }
168                    return Ok(Some((packet, fixed.remaining_length)));
169                }
170            }
171        }
172    }
173}
174
175impl Encoder<Packet> for Codec {
176    // type Item = Packet;
177    type Error = EncodeError;
178
179    fn encode(&mut self, mut item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
180        // handle [MQTT 3.1.2.11.7]
181        if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
182            match item {
183                Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt) => {
184                    pkt.properties.clear();
185                    let _ = pkt.reason_string.take();
186                }
187                Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt) => {
188                    pkt.properties.clear();
189                    let _ = pkt.reason_string.take();
190                }
191                Packet::Subscribe(ref mut pkt) => {
192                    pkt.user_properties.clear();
193                }
194                Packet::SubscribeAck(ref mut pkt) => {
195                    pkt.properties.clear();
196                    let _ = pkt.reason_string.take();
197                }
198                Packet::Unsubscribe(ref mut pkt) => {
199                    pkt.user_properties.clear();
200                }
201                Packet::UnsubscribeAck(ref mut pkt) => {
202                    pkt.properties.clear();
203                    let _ = pkt.reason_string.take();
204                }
205                Packet::Auth(ref mut pkt) => {
206                    pkt.user_properties.clear();
207                    let _ = pkt.reason_string.take();
208                }
209                _ => (),
210            }
211        }
212
213        let max_out_size = self.max_out_size.get();
214        let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
215        let content_size = item.encoded_size(max_size);
216        if content_size > max_size as usize {
217            return Err(EncodeError::OverMaxPacketSize);
218        }
219        dst.reserve(content_size + 5);
220        item.encode(dst, content_size as u32)?; // safe: max_size <= u32 max value
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// A header announcing a 9-byte body is rejected when the inbound limit is 5.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::new(5, 5);
        // First byte 0x00, remaining length 9.
        let mut buf = BytesMut::from(&b"\0\x09"[..]);
        let outcome = codec.decode(&mut buf);
        assert!(matches!(outcome, Err(DecodeError::MaxSizeExceeded)));
    }
}