// rmqtt_codec/v5/codec.rs

1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode::decode_packet, encode::EncodeLtd, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, MAX_PACKET_SIZE};
9use crate::utils::decode_variable_length;
10
/// MQTT v5 packet codec implementing `tokio_util`'s `Decoder` and `Encoder`.
///
/// All mutable state is held in `Cell`s, so limits and capability flags can be
/// read and updated through a shared reference.
#[derive(Debug, Clone)]
pub struct Codec {
    /// Position of the frame decoder (header vs. body) between `decode` calls.
    state: Cell<DecodeState>,
    /// Largest accepted inbound remaining length; `0` means unlimited.
    max_in_size: Cell<u32>,
    /// Largest allowed outbound packet content size; `0` means unlimited.
    max_out_size: Cell<u32>,
    /// Capability flags that alter what the codec reads or writes.
    flags: Cell<CodecFlags>,
}
18
bitflags::bitflags! {
    /// Per-connection flags consulted by [`Codec`] while encoding and decoding.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct CodecFlags: u8 {
        /// Set when the decoded CONNECT packet had `request_problem_info == false`;
        /// the encoder then strips reason strings and user properties.
        const NO_PROBLEM_INFO = 0b0000_0001;
        /// When set, `Codec::retain_available` reports `false`.
        const NO_RETAIN       = 0b0000_0010;
        /// When set, `Codec::sub_ids_available` reports `false`.
        const NO_SUB_IDS      = 0b0000_1000;
    }
}
27
/// Progress of the frame decoder, kept across `decode` calls.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting for a fixed header: first byte plus variable-length remaining length.
    FrameHeader,
    /// Fixed header parsed; waiting for `remaining_length` bytes of packet body.
    Frame(FixedHeader),
}
33
34impl Codec {
35    /// Create `Codec` instance
36    pub fn new(max_in_size: u32, max_out_size: u32) -> Self {
37        Codec {
38            state: Cell::new(DecodeState::FrameHeader),
39            max_in_size: Cell::new(max_in_size),
40            max_out_size: Cell::new(max_out_size),
41            flags: Cell::new(CodecFlags::empty()),
42        }
43    }
44
45    /// Set max inbound frame size.
46    ///
47    /// If max size is set to `0`, size is unlimited.
48    /// By default max size is set to `0`
49    pub fn max_inbound_size(&self) -> u32 {
50        self.max_in_size.get()
51    }
52
53    /// Set max outbound frame size.
54    ///
55    /// If max size is set to `0`, size is unlimited.
56    /// By default max size is set to `0`
57    pub fn max_outbound_size(&self) -> u32 {
58        self.max_out_size.get()
59    }
60
61    /// Set max inbound frame size.
62    ///
63    /// If max size is set to `0`, size is unlimited.
64    /// By default max size is set to `0`
65    pub fn set_max_inbound_size(&mut self, size: u32) {
66        self.max_in_size.set(size);
67    }
68
69    /// Set max outbound frame size.
70    ///
71    /// If max size is set to `0`, size is unlimited.
72    /// By default max size is set to `0`
73    pub fn set_max_outbound_size(&mut self, mut size: u32) {
74        if size > 5 {
75            // fixed header = 1, var_len(remaining.max_value()) = 4
76            size -= 5;
77        }
78        self.max_out_size.set(size);
79    }
80
81    #[inline]
82    #[allow(dead_code)]
83    pub(crate) fn retain_available(&self) -> bool {
84        !self.flags.get().contains(CodecFlags::NO_RETAIN)
85    }
86
87    #[inline]
88    #[allow(dead_code)]
89    pub(crate) fn sub_ids_available(&self) -> bool {
90        !self.flags.get().contains(CodecFlags::NO_SUB_IDS)
91    }
92
93    #[inline]
94    #[allow(dead_code)]
95    pub(crate) fn set_retain_available(&self, val: bool) {
96        let mut flags = self.flags.get();
97        flags.set(CodecFlags::NO_RETAIN, !val);
98        self.flags.set(flags);
99    }
100
101    #[inline]
102    #[allow(dead_code)]
103    pub(crate) fn set_sub_ids_available(&self, val: bool) {
104        let mut flags = self.flags.get();
105        flags.set(CodecFlags::NO_SUB_IDS, !val);
106        self.flags.set(flags);
107    }
108}
109
110impl Default for Codec {
111    fn default() -> Self {
112        Self::new(0, 0)
113    }
114}
115
116impl Decoder for Codec {
117    type Item = (Packet, u32);
118    type Error = DecodeError;
119
120    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
121        loop {
122            match self.state.get() {
123                DecodeState::FrameHeader => {
124                    if src.len() < 2 {
125                        return Ok(None);
126                    }
127                    let src_slice = src.as_ref();
128                    let first_byte = src_slice[0];
129                    match decode_variable_length(&src_slice[1..])? {
130                        Some((remaining_length, consumed)) => {
131                            // check max message size
132                            let max_in_size = self.max_in_size.get();
133                            if max_in_size != 0 && max_in_size < remaining_length {
134                                log::debug!(
135                                    "MaxSizeExceeded max-size: {}, remaining: {}",
136                                    max_in_size,
137                                    remaining_length
138                                );
139                                return Err(DecodeError::MaxSizeExceeded);
140                            }
141                            src.advance(consumed + 1);
142                            self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
143                            // todo: validate remaining_length against max frame size config
144                            let remaining_length = remaining_length as usize;
145                            if src.len() < remaining_length {
146                                // todo: subtract?
147                                src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
148                                return Ok(None);
149                            }
150                        }
151                        None => {
152                            return Ok(None);
153                        }
154                    }
155                }
156                DecodeState::Frame(fixed) => {
157                    if src.len() < fixed.remaining_length as usize {
158                        return Ok(None);
159                    }
160                    let packet_buf = src.split_to(fixed.remaining_length as usize).freeze();
161                    let packet = decode_packet(packet_buf, fixed.first_byte)?;
162                    self.state.set(DecodeState::FrameHeader);
163                    src.reserve(5); // enough to fix 1 fixed header byte + 4 bytes max variable packet length
164
165                    if let Packet::Connect(ref pkt) = packet {
166                        let mut flags = self.flags.get();
167                        flags.set(CodecFlags::NO_PROBLEM_INFO, !pkt.request_problem_info);
168                        self.flags.set(flags);
169                    }
170                    return Ok(Some((packet, fixed.remaining_length)));
171                }
172            }
173        }
174    }
175}
176
177impl Encoder<Packet> for Codec {
178    // type Item = Packet;
179    type Error = EncodeError;
180
181    fn encode(&mut self, mut item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
182        // handle [MQTT 3.1.2.11.7]
183        if self.flags.get().contains(CodecFlags::NO_PROBLEM_INFO) {
184            match item {
185                Packet::PublishAck(ref mut pkt) | Packet::PublishReceived(ref mut pkt) => {
186                    pkt.properties.clear();
187                    let _ = pkt.reason_string.take();
188                }
189                Packet::PublishRelease(ref mut pkt) | Packet::PublishComplete(ref mut pkt) => {
190                    pkt.properties.clear();
191                    let _ = pkt.reason_string.take();
192                }
193                Packet::Subscribe(ref mut pkt) => {
194                    pkt.user_properties.clear();
195                }
196                Packet::SubscribeAck(ref mut pkt) => {
197                    pkt.properties.clear();
198                    let _ = pkt.reason_string.take();
199                }
200                Packet::Unsubscribe(ref mut pkt) => {
201                    pkt.user_properties.clear();
202                }
203                Packet::UnsubscribeAck(ref mut pkt) => {
204                    pkt.properties.clear();
205                    let _ = pkt.reason_string.take();
206                }
207                Packet::Auth(ref mut pkt) => {
208                    pkt.user_properties.clear();
209                    let _ = pkt.reason_string.take();
210                }
211                _ => (),
212            }
213        }
214
215        let max_out_size = self.max_out_size.get();
216        let max_size = if max_out_size != 0 { max_out_size } else { MAX_PACKET_SIZE };
217        let content_size = item.encoded_size(max_size);
218        if content_size > max_size as usize {
219            return Err(EncodeError::OverMaxPacketSize);
220        }
221        dst.reserve(content_size + 5);
222        item.encode(dst, content_size as u32)?; // safe: max_size <= u32 max value
223        Ok(())
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed header announcing a body over the inbound limit is rejected.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::new(5, 5);
        let mut buf = BytesMut::new();
        // Type byte 0, remaining length 9 (> 5).
        buf.extend_from_slice(b"\0\x09");
        assert!(matches!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded)));
    }

    /// A remaining length equal to the limit is accepted.
    #[test]
    fn test_size_at_limit_is_accepted() {
        let mut codec = Codec::new(9, 0);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        // Header accepted; the 9-byte body has not arrived yet.
        assert!(matches!(codec.decode(&mut buf), Ok(None)));
    }

    /// With fewer than two bytes nothing is consumed and more data is requested.
    #[test]
    fn test_short_header_needs_more_data() {
        let mut codec = Codec::default();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0");
        assert!(matches!(codec.decode(&mut buf), Ok(None)));
        assert_eq!(buf.len(), 1);
    }

    /// With no inbound limit, the header is consumed and room is made for the body.
    #[test]
    fn test_unlimited_size_waits_for_body() {
        let mut codec = Codec::new(0, 0);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert!(matches!(codec.decode(&mut buf), Ok(None)));
        assert!(buf.is_empty());
        assert!(buf.capacity() >= 9);
    }
}