rmqtt_codec/v3/
codec.rs

1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode, encode, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, QoS};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12/// Mqtt v3.1.1 protocol codec
13pub struct Codec {
14    state: Cell<DecodeState>,
15    max_size: Cell<u32>,
16}
17
18#[derive(Debug, Copy, Clone, PartialEq, Eq)]
19enum DecodeState {
20    FrameHeader,
21    Frame(FixedHeader),
22}
23
24impl Codec {
25    /// Create `Codec` instance
26    pub fn new(max_packet_size: u32) -> Self {
27        Codec { state: Cell::new(DecodeState::FrameHeader), max_size: Cell::new(max_packet_size) }
28    }
29
30    /// Set max inbound frame size.
31    ///
32    /// If max size is set to `0`, size is unlimited.
33    /// By default max size is set to `0`
34    pub fn set_max_size(&mut self, size: u32) {
35        self.max_size.set(size);
36    }
37}
38
39impl Default for Codec {
40    fn default() -> Self {
41        Self::new(0)
42    }
43}
44
45impl Decoder for Codec {
46    type Item = (Packet, u32);
47    type Error = DecodeError;
48
49    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
50        loop {
51            match self.state.get() {
52                DecodeState::FrameHeader => {
53                    if src.len() < 2 {
54                        return Ok(None);
55                    }
56                    let src_slice = src.as_ref();
57                    let first_byte = src_slice[0];
58                    match decode_variable_length(&src_slice[1..])? {
59                        Some((remaining_length, consumed)) => {
60                            // check max message size
61                            let max_size = self.max_size.get();
62                            if max_size != 0 && max_size < remaining_length {
63                                return Err(DecodeError::MaxSizeExceeded);
64                            }
65                            src.advance(consumed + 1);
66                            self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
67                            // todo: validate remaining_length against max frame size config
68                            let remaining_length = remaining_length as usize;
69                            if src.len() < remaining_length {
70                                // todo: subtract?
71                                src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
72                                return Ok(None);
73                            }
74                        }
75                        None => {
76                            return Ok(None);
77                        }
78                    }
79                }
80                DecodeState::Frame(fixed) => {
81                    if src.len() < fixed.remaining_length as usize {
82                        return Ok(None);
83                    }
84                    let packet_buf = src.split_to(fixed.remaining_length as usize);
85                    let packet = decode::decode_packet(packet_buf.freeze(), fixed.first_byte)?;
86                    self.state.set(DecodeState::FrameHeader);
87                    src.reserve(2);
88                    return Ok(Some((packet, fixed.remaining_length)));
89                }
90            }
91        }
92    }
93}
94
95impl Encoder<Packet> for Codec {
96    // type Item = Packet;
97    type Error = EncodeError;
98
99    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
100        if let Packet::Publish(ref publish) = item {
101            if (publish.qos == QoS::AtLeastOnce || publish.qos == QoS::ExactlyOnce)
102                && publish.packet_id.is_none()
103            {
104                return Err(EncodeError::PacketIdRequired);
105            }
106        }
107        let content_size = encode::get_encoded_size(&item);
108        dst.reserve(content_size + 5);
109        encode::encode(&item, dst, content_size as u32)?;
110        Ok(())
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v3::packet::Publish;
    use bytes::Bytes;
    use bytestring::ByteString;

    /// Builds a QoS 0 `PUBLISH` on topic `/test` carrying `payload`.
    fn qos0_publish(payload: Bytes) -> Box<Publish> {
        Box::new(Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload,
            properties: None,
            delay_interval: None,
            create_time: None,
        })
    }

    #[test]
    fn test_max_size() {
        let mut codec = Codec::default();
        codec.set_max_size(5);

        let mut buf = BytesMut::new();
        // Remaining length 9 is above the limit of 5.
        buf.extend_from_slice(b"\0\x09");
        assert!(matches!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded)));
    }

    #[test]
    fn test_packet() {
        let mut codec = Codec::default();
        let mut buf = BytesMut::new();

        let mut pkt = qos0_publish(Bytes::from(Vec::from("a".repeat(260 * 1024))));
        codec.encode(Packet::Publish(pkt.clone()), &mut buf).unwrap();

        let pkt2 = match codec.decode(&mut buf).unwrap().unwrap() {
            (Packet::Publish(v), _) => v,
            _ => panic!("expected a PUBLISH packet"),
        };
        // The decoded packet carries a `create_time` stamped during decoding. Copy that value
        // rather than sampling the clock again: a second sample can fall on a different
        // millisecond and make this comparison flaky.
        pkt.create_time = pkt2.create_time;
        assert_eq!(pkt, pkt2);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_packet_split_across_reads() {
        let mut codec = Codec::default();
        let mut encoded = BytesMut::new();
        codec.encode(Packet::Publish(qos0_publish(Bytes::from_static(b"hello"))), &mut encoded).unwrap();

        // First chunk: fixed header (2 bytes for a short frame) plus one body byte.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&encoded[..3]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // The rest of the frame arrives; the remembered header must be reused.
        buf.extend_from_slice(&encoded[3..]);
        let (packet, remaining_length) = codec.decode(&mut buf).unwrap().unwrap();
        assert!(matches!(packet, Packet::Publish(_)));
        assert_eq!(remaining_length as usize, encoded.len() - 2);
        assert!(buf.is_empty());
    }
}
155}