// mqtt_codec/codec/mod.rs

1use std::io::Cursor;
2
3use actix_codec::{Decoder, Encoder};
4use bytes::BytesMut;
5
6use crate::error::ParseError;
7use crate::proto::QoS;
8use crate::{Packet, Publish};
9
10mod decode;
11mod encode;
12
13use self::decode::*;
14use self::encode::*;
15
16bitflags! {
17    pub struct ConnectFlags: u8 {
18        const USERNAME      = 0b1000_0000;
19        const PASSWORD      = 0b0100_0000;
20        const WILL_RETAIN   = 0b0010_0000;
21        const WILL_QOS      = 0b0001_1000;
22        const WILL          = 0b0000_0100;
23        const CLEAN_SESSION = 0b0000_0010;
24    }
25}
26
/// Right-shift that moves the two-bit `ConnectFlags::WILL_QOS` field
/// (`0b0001_1000`) down to bits 0..=1.
pub const WILL_QOS_SHIFT: u8 = 3;
28
29bitflags! {
30    pub struct ConnectAckFlags: u8 {
31        const SESSION_PRESENT = 0b0000_0001;
32    }
33}
34
/// MQTT packet codec, implementing both `Decoder` and `Encoder` for `Packet`.
///
/// Decoding is incremental: the codec remembers whether it is waiting for a
/// fixed header or for the body belonging to an already-parsed header.
#[derive(Debug)]
pub struct Codec {
    // Current decoding stage.
    state: DecodeState,
    // Maximum accepted remaining length of an inbound packet; `0` means unlimited.
    max_size: usize,
}
40
/// Stage of the incremental decoder.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting for (or parsing) the fixed header of the next packet.
    FrameHeader,
    /// Fixed header consumed; waiting until `remaining_length` body bytes are buffered.
    Frame(FixedHeader),
}
46
47impl Codec {
48    /// Create `Codec` instance
49    pub fn new() -> Self {
50        Codec {
51            state: DecodeState::FrameHeader,
52            max_size: 0,
53        }
54    }
55
56    /// Set max inbound frame size.
57    ///
58    /// If max size is set to `0`, size is unlimited.
59    /// By default max size is set to `0`
60    pub fn max_size(mut self, size: usize) -> Self {
61        self.max_size = size;
62        self
63    }
64}
65
66impl Default for Codec {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl Decoder for Codec {
73    type Item = Packet;
74    type Error = ParseError;
75
76    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, ParseError> {
77        loop {
78            match self.state {
79                DecodeState::FrameHeader => {
80                    if src.len() < 2 {
81                        return Ok(None);
82                    }
83                    let fixed = src.as_ref()[0];
84                    match decode_variable_length(&src.as_ref()[1..])? {
85                        Some((remaining_length, consumed)) => {
86                            // check max message size
87                            if self.max_size != 0 && self.max_size < remaining_length {
88                                return Err(ParseError::MaxSizeExceeded);
89                            }
90                            src.split_to(consumed + 1);
91                            self.state = DecodeState::Frame(FixedHeader {
92                                packet_type: fixed >> 4,
93                                packet_flags: fixed & 0xF,
94                                remaining_length,
95                            });
96                            // todo: validate remaining_length against max frame size config
97                            if src.len() < remaining_length {
98                                // todo: subtract?
99                                src.reserve(remaining_length); // extend receiving buffer to fit the whole frame -- todo: too eager?
100                                return Ok(None);
101                            }
102                        }
103                        None => {
104                            return Ok(None);
105                        }
106                    }
107                }
108                DecodeState::Frame(fixed) => {
109                    if src.len() < fixed.remaining_length {
110                        return Ok(None);
111                    }
112                    let packet_buf = src.split_to(fixed.remaining_length);
113                    let mut packet_cur = Cursor::new(packet_buf.freeze());
114                    let packet = read_packet(&mut packet_cur, fixed)?;
115                    self.state = DecodeState::FrameHeader;
116                    src.reserve(2);
117                    return Ok(Some(packet));
118                }
119            }
120        }
121    }
122}
123
124impl Encoder for Codec {
125    type Item = Packet;
126    type Error = ParseError;
127
128    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), ParseError> {
129        if let Packet::Publish(Publish { qos, packet_id, .. }) = item {
130            if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() {
131                return Err(ParseError::PacketIdRequired);
132            }
133        }
134        let content_size = get_encoded_size(&item);
135        dst.reserve(content_size + 5);
136        write_packet(&item, dst, content_size);
137        Ok(())
138    }
139}
140
/// Decoded MQTT fixed header: the first byte split into packet type and
/// flags, plus the decoded remaining length.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) struct FixedHeader {
    /// MQTT Control Packet type (high nibble of the first header byte).
    pub packet_type: u8,
    /// Flags specific to each MQTT Control Packet type (low nibble of the
    /// first header byte).
    pub packet_flags: u8,
    /// the number of bytes remaining within the current packet,
    /// including data in the variable header and the payload.
    pub remaining_length: usize,
}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_size() {
        let mut codec = Codec::new().max_size(5);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf), Err(ParseError::MaxSizeExceeded));
    }

    #[test]
    fn test_max_size_boundary_is_accepted() {
        // A remaining length equal to `max_size` is allowed; the body is not
        // buffered yet, so the decoder asks for more data.
        let mut codec = Codec::new().max_size(5);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x05");
        assert_eq!(codec.decode(&mut buf), Ok(None));
    }

    #[test]
    fn test_max_size_zero_is_unlimited() {
        // Default `max_size` is 0, which must not reject any length.
        let mut codec = Codec::new();

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf), Ok(None));
    }
}