embedded_mqtt/variable_header/connect.rs

1use core::{
2    fmt::Debug,
3    convert::{TryInto, TryFrom, From},
4    result::Result,
5};
6
7use crate::{
8    fixed_header::PacketFlags,
9    codec::{self, Encodable},
10    status::Status,
11    error::{DecodeError, EncodeError},
12    qos,
13};
14
15use super::HeaderDecode;
16
17use bitfield::BitRange;
18
/// Protocols whose name can be written into a CONNECT variable header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Identified on the wire by the name `"MQTT"`.
    MQTT,
}

impl Protocol {
    /// Returns the protocol name string for this protocol.
    fn name(self) -> &'static str {
        match self {
            Self::MQTT => "MQTT",
        }
    }
}
31
/// Protocol revision ("protocol level") carried in the CONNECT header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Level {
    /// MQTT 3.1.1, written on the wire as the byte `4`.
    Level3_1_1,
}

/// Parses a protocol-level byte; only `4` is recognised, anything else is `Err(())`.
impl TryFrom<u8> for Level {
    type Error = ();

    fn try_from(val: u8) -> Result<Self, Self::Error> {
        match val {
            4 => Ok(Self::Level3_1_1),
            _ => Err(()),
        }
    }
}

/// Converts a protocol level back into its wire byte.
impl From<Level> for u8 {
    fn from(level: Level) -> u8 {
        match level {
            Level::Level3_1_1 => 4,
        }
    }
}
55
/// Raw CONNECT flags byte.
///
/// Bit layout used by the accessors on this type: bit 7 user name, bit 6
/// password, bit 5 will retain, bits 4..3 will QoS, bit 2 will flag, bit 1
/// clean session. Bit 0 is the reserved bit and has no accessor.
#[derive(PartialEq, Clone, Copy, Default)]
pub struct Flags(u8);

// Implements `bitfield::BitRange` for `Flags` over its inner byte, which the
// `bit_range` / `set_bit_range` calls and generated field accessors rely on.
bitfield_bitrange! {
    struct Flags(u8)
}
62
63impl Flags {
64    bitfield_fields! {
65        bool;
66        pub has_username,  set_has_username  : 7;
67        pub has_password,  set_has_password  : 6;
68        pub will_retain,   set_will_retain   : 5;
69        
70        pub has_will,      set_has_will_flag : 2;
71        pub clean_session, set_clean_session : 1;
72    }
73
74    pub fn will_qos(&self) -> Result<qos::QoS, qos::Error> {
75        let qos_bits: u8 = self.bit_range(4, 3);
76        qos_bits.try_into()
77    }
78
79    #[allow(dead_code)]
80    pub fn set_will_qos(&mut self, qos: qos::QoS) {
81        self.set_bit_range(4, 3, u8::from(qos))
82    }
83}
84
85impl From<Flags> for u8 {
86    fn from(val: Flags) -> u8 {
87        val.0
88    }
89}
90
// Debug output is generated by `bitfield_debug!` from the accessors listed
// below, using the same bit positions as the `bitfield_fields!` declaration.
// NOTE(review): `will_qos` returns a `Result`, so confirm how the `into QoS`
// entry is rendered by the bitfield crate version in use.
impl Debug for Flags {
    bitfield_debug! {
        struct Flags;
        pub has_username, _       : 7;
        pub has_password, _       : 6;
        pub will_retain, _        : 5;
        pub into QoS, will_qos, _ : 4, 3;
        pub has_will, _           : 2;
        pub clean_session, _      : 1;
    }
}
102
/// Variable header of an MQTT CONNECT packet.
#[derive(PartialEq, Debug)]
pub struct Connect<'buf> {
    /// Protocol name; borrowed from the decode buffer, or a `'static` string
    /// when built with `Connect::new`.
    name: &'buf str,
    /// Protocol revision.
    level: Level,
    /// Raw connect flags byte.
    flags: Flags,
    /// Keep-alive interval (seconds), carried as a 16-bit value.
    keep_alive: u16,
}
111
112impl<'buf> Connect<'buf> {
113    pub fn new(protocol: Protocol, level: Level, flags: Flags, keep_alive: u16) -> Self {
114        let name = protocol.name();
115        Connect {
116            name: name,
117            level,
118            flags,
119            keep_alive
120        }
121    }
122
123    pub fn name(&self) -> &str {
124        self.name
125    }
126
127    pub fn level(&self) -> Level {
128        self.level
129    }
130
131    pub fn flags(&self) -> Flags {
132        self.flags
133    }
134
135    pub fn keep_alive(&self) -> u16 {
136        self.keep_alive
137    }
138}
139
140impl<'buf> HeaderDecode<'buf> for Connect<'buf> {
141    fn decode(_flags: PacketFlags, bytes: &'buf [u8]) -> Result<Status<(usize, Connect<'buf>)>, DecodeError> {
142        let offset = 0;
143
144        // read protocol name
145        let (offset, name) = read!(codec::string::parse_string, bytes, offset);
146
147        // read protocol revision
148        let (offset, level) = read!(codec::values::parse_u8, bytes, offset);
149
150        let level = level.try_into().map_err(|_| DecodeError::InvalidProtocolLevel)?;
151        if level != Level::Level3_1_1 {
152            return Err(DecodeError::InvalidProtocolLevel)
153        }
154
155        // read protocol flags
156        let (offset, flags) = read!(codec::values::parse_u8, bytes, offset);
157
158        let flags = Flags(flags);
159
160        if let Err(e) = flags.will_qos() {
161            match e {
162                qos::Error::BadPattern => return Err(DecodeError::InvalidConnectFlag),
163            }
164        }
165
166        // read protocol keep alive
167        let (offset, keep_alive) = read!(codec::values::parse_u16, bytes, offset);
168
169        Ok(Status::Complete((offset, Connect {
170            name,
171            level,
172            flags,
173            keep_alive,
174        })))
175    }
176}
177
178impl<'buf> Encodable for Connect<'buf> {
179    fn encoded_len(&self) -> usize {
180        self.name.encoded_len() + 1 + 1 + 2
181    }
182
183    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
184        let offset = 0;
185        let offset = {
186            let o = codec::string::encode_string(self.name, &mut bytes[offset..])?;
187            (offset + o)
188        };
189        let offset = {
190            let o = codec::values::encode_u8(self.level.into(), &mut bytes[offset..])?;
191            (offset + o)
192        };
193        let offset = {
194            let o = codec::values::encode_u8(self.flags.into(), &mut bytes[offset..])?;
195            (offset + o)
196        };
197        let offset = {
198            let o = codec::values::encode_u16(self.keep_alive, &mut bytes[offset..])?;
199            (offset + o)
200        };
201        Ok(offset)
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_flags() {
        let flags = Flags(0b11100110);
        assert!(flags.has_username());
        assert!(flags.has_password());
        assert!(flags.will_retain());
        assert!(flags.has_will());
        assert!(flags.clean_session());

        let flags = Flags(0b00000000);
        assert!(!flags.has_username());
        assert!(!flags.has_password());
        assert!(!flags.will_retain());
        assert!(!flags.has_will());
        assert!(!flags.clean_session());
    }

    #[test]
    fn parse_qos() {
        let flags = Flags(0b00010000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::ExactlyOnce));

        let flags = Flags(0b00001000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtLeastOnce));

        let flags = Flags(0b00000000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtMostOnce));
    }

    #[test]
    fn parse_connect() {
        let buf = [
            0b00000000, // Protocol Name Length
            0b00000100,
            0b01001101, // 'M'
            0b01010001, // 'Q'
            0b01010100, // 'T'
            0b01010100, // 'T'
            0b00000100, // Level 4
            0b11001110, // Connect Flags - Username 1
                        //               - Password 1
                        //               - Will Retain 0
                        //               - Will QoS 01
                        //               - Will Flag 1
                        //               - Clean Session 1
                        //               - Reserved 0
            0b00000000, // Keep Alive (10s)
            0b00001010,
        ];

        let connect = Connect::decode(PacketFlags::CONNECT, &buf);

        assert_eq!(connect, Ok(Status::Complete((10, Connect {
            name: "MQTT",
            level: Level::Level3_1_1,
            flags: Flags(0b11001110),
            keep_alive: 10,
        }))));
    }

    #[test]
    fn reject_unknown_level() {
        let buf = [
            0b00000000, 0b00000100, b'M', b'Q', b'T', b'T',
            0b00000011, // Level 3 (unsupported)
            0b00000010, // Clean Session
            0b00000000, 0b00001010,
        ];

        assert_eq!(
            Connect::decode(PacketFlags::CONNECT, &buf),
            Err(DecodeError::InvalidProtocolLevel)
        );
    }

    #[test]
    fn reject_invalid_will_qos() {
        let buf = [
            0b00000000, 0b00000100, b'M', b'Q', b'T', b'T',
            0b00000100, // Level 4
            0b00011100, // Will QoS 11 (invalid), Will Flag 1
            0b00000000, 0b00001010,
        ];

        assert_eq!(
            Connect::decode(PacketFlags::CONNECT, &buf),
            Err(DecodeError::InvalidConnectFlag)
        );
    }

    #[test]
    fn encode_decode_round_trip() {
        let mut flags = Flags::default();
        flags.set_has_username(true);
        flags.set_has_password(true);
        flags.set_has_will_flag(true);
        flags.set_will_qos(qos::QoS::AtLeastOnce);
        flags.set_clean_session(true);
        let connect = Connect::new(Protocol::MQTT, Level::Level3_1_1, flags, 10);

        let len = connect.encoded_len();
        let mut buf = [0u8; 32];
        assert_eq!(connect.encode(&mut buf).ok(), Some(len));

        assert_eq!(
            Connect::decode(PacketFlags::CONNECT, &buf[..len]),
            Ok(Status::Complete((len, connect)))
        );
    }
}