embedded_mqtt/variable_header/
connect.rs1use core::{
2 fmt::Debug,
3 convert::{TryInto, TryFrom, From},
4 result::Result,
5};
6
7use crate::{
8 fixed_header::PacketFlags,
9 codec::{self, Encodable},
10 status::Status,
11 error::{DecodeError, EncodeError},
12 qos,
13};
14
15use super::HeaderDecode;
16
17use bitfield::BitRange;
18
/// Protocols whose name can appear in a CONNECT variable header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Protocol {
    /// MQTT; written on the wire as the string `"MQTT"`.
    MQTT,
}

impl Protocol {
    /// Returns the protocol-name string stored in a CONNECT header built for this protocol.
    fn name(self) -> &'static str {
        match self {
            Self::MQTT => "MQTT",
        }
    }
}
31
/// Protocol levels understood by this implementation (the CONNECT "Protocol Level" byte).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Level {
    /// MQTT 3.1.1, carried on the wire as the byte `4`.
    Level3_1_1,
}

impl TryFrom<u8> for Level {
    type Error = ();

    /// Maps a protocol-level byte to a [`Level`]; every value other than `4` yields `Err(())`.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        match byte {
            4 => Ok(Level::Level3_1_1),
            _ => Err(()),
        }
    }
}

impl From<Level> for u8 {
    /// Returns the protocol-level byte written on the wire for `level`.
    fn from(level: Level) -> Self {
        match level {
            Level::Level3_1_1 => 4,
        }
    }
}
55
56#[derive(PartialEq, Clone, Copy, Default)]
57pub struct Flags(u8);
58
59bitfield_bitrange! {
60 struct Flags(u8)
61}
62
63impl Flags {
64 bitfield_fields! {
65 bool;
66 pub has_username, set_has_username : 7;
67 pub has_password, set_has_password : 6;
68 pub will_retain, set_will_retain : 5;
69
70 pub has_will, set_has_will_flag : 2;
71 pub clean_session, set_clean_session : 1;
72 }
73
74 pub fn will_qos(&self) -> Result<qos::QoS, qos::Error> {
75 let qos_bits: u8 = self.bit_range(4, 3);
76 qos_bits.try_into()
77 }
78
79 #[allow(dead_code)]
80 pub fn set_will_qos(&mut self, qos: qos::QoS) {
81 self.set_bit_range(4, 3, u8::from(qos))
82 }
83}
84
85impl From<Flags> for u8 {
86 fn from(val: Flags) -> u8 {
87 val.0
88 }
89}
90
// `Debug` for `Flags` is generated by `bitfield_debug!` from the getters listed below.
// The order of the entries determines the order of fields in the output.
impl Debug for Flags {
    bitfield_debug! {
        struct Flags;
        pub has_username, _ : 7;
        pub has_password, _ : 6;
        pub will_retain, _ : 5;
        // `will_qos()` returns `Result<qos::QoS, qos::Error>`.
        // NOTE(review): `QoS` is not imported in this file (only the `qos` module is), so the
        // `into QoS` marker appears to be ignored by the macro — confirm against the bitfield
        // version in use.
        pub into QoS, will_qos, _ : 4, 3;
        pub has_will, _ : 2;
        pub clean_session, _ : 1;
    }
}
102
103#[derive(PartialEq, Debug)]
105pub struct Connect<'buf> {
106 name: &'buf str,
107 level: Level,
108 flags: Flags,
109 keep_alive: u16,
110}
111
112impl<'buf> Connect<'buf> {
113 pub fn new(protocol: Protocol, level: Level, flags: Flags, keep_alive: u16) -> Self {
114 let name = protocol.name();
115 Connect {
116 name: name,
117 level,
118 flags,
119 keep_alive
120 }
121 }
122
123 pub fn name(&self) -> &str {
124 self.name
125 }
126
127 pub fn level(&self) -> Level {
128 self.level
129 }
130
131 pub fn flags(&self) -> Flags {
132 self.flags
133 }
134
135 pub fn keep_alive(&self) -> u16 {
136 self.keep_alive
137 }
138}
139
140impl<'buf> HeaderDecode<'buf> for Connect<'buf> {
141 fn decode(_flags: PacketFlags, bytes: &'buf [u8]) -> Result<Status<(usize, Connect<'buf>)>, DecodeError> {
142 let offset = 0;
143
144 let (offset, name) = read!(codec::string::parse_string, bytes, offset);
146
147 let (offset, level) = read!(codec::values::parse_u8, bytes, offset);
149
150 let level = level.try_into().map_err(|_| DecodeError::InvalidProtocolLevel)?;
151 if level != Level::Level3_1_1 {
152 return Err(DecodeError::InvalidProtocolLevel)
153 }
154
155 let (offset, flags) = read!(codec::values::parse_u8, bytes, offset);
157
158 let flags = Flags(flags);
159
160 if let Err(e) = flags.will_qos() {
161 match e {
162 qos::Error::BadPattern => return Err(DecodeError::InvalidConnectFlag),
163 }
164 }
165
166 let (offset, keep_alive) = read!(codec::values::parse_u16, bytes, offset);
168
169 Ok(Status::Complete((offset, Connect {
170 name,
171 level,
172 flags,
173 keep_alive,
174 })))
175 }
176}
177
178impl<'buf> Encodable for Connect<'buf> {
179 fn encoded_len(&self) -> usize {
180 self.name.encoded_len() + 1 + 1 + 2
181 }
182
183 fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
184 let offset = 0;
185 let offset = {
186 let o = codec::string::encode_string(self.name, &mut bytes[offset..])?;
187 (offset + o)
188 };
189 let offset = {
190 let o = codec::values::encode_u8(self.level.into(), &mut bytes[offset..])?;
191 (offset + o)
192 };
193 let offset = {
194 let o = codec::values::encode_u8(self.flags.into(), &mut bytes[offset..])?;
195 (offset + o)
196 };
197 let offset = {
198 let o = codec::values::encode_u16(self.keep_alive, &mut bytes[offset..])?;
199 (offset + o)
200 };
201 Ok(offset)
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // A complete CONNECT variable header:
    // protocol "MQTT", level 4, flags 0b1100_1110, keep-alive 10.
    const CONNECT_HEADER: [u8; 10] = [
        0x00, 0x04, // protocol name length
        b'M', b'Q', b'T', b'T', // protocol name
        0x04, // protocol level (3.1.1)
        0b1100_1110, // flags: user name, password, will QoS 1, will, clean session
        0x00, 0x0A, // keep-alive = 10
    ];

    #[test]
    fn parse_flags() {
        let flags = Flags(0b1110_0110);
        assert!(flags.has_username());
        assert!(flags.has_password());
        assert!(flags.will_retain());
        assert!(flags.has_will());
        assert!(flags.clean_session());

        let flags = Flags(0b0000_0000);
        assert!(!flags.has_username());
        assert!(!flags.has_password());
        assert!(!flags.will_retain());
        assert!(!flags.has_will());
        assert!(!flags.clean_session());
    }

    #[test]
    fn parse_qos() {
        let flags = Flags(0b0001_0000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::ExactlyOnce));

        let flags = Flags(0b0000_1000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtLeastOnce));

        let flags = Flags(0b0000_0000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtMostOnce));
    }

    #[test]
    fn set_will_qos_round_trips_and_keeps_other_bits() {
        let mut flags = Flags(0b1110_0110);
        flags.set_will_qos(qos::QoS::AtLeastOnce);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtLeastOnce));
        assert!(flags.has_username());
        assert!(flags.has_password());
        assert!(flags.will_retain());
        assert!(flags.has_will());
        assert!(flags.clean_session());

        let mut flags = Flags::default();
        flags.set_will_qos(qos::QoS::ExactlyOnce);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::ExactlyOnce));
    }

    #[test]
    fn parse_connect() {
        let connect = Connect::decode(PacketFlags::CONNECT, &CONNECT_HEADER);

        assert_eq!(
            connect,
            Ok(Status::Complete((
                10,
                Connect {
                    name: "MQTT",
                    level: Level::Level3_1_1,
                    flags: Flags(0b1100_1110),
                    keep_alive: 10,
                }
            )))
        );
    }

    #[test]
    fn reject_unsupported_protocol_level() {
        let mut buf = CONNECT_HEADER;
        buf[6] = 3;
        assert_eq!(
            Connect::decode(PacketFlags::CONNECT, &buf),
            Err(DecodeError::InvalidProtocolLevel)
        );
    }

    #[test]
    fn encode_matches_wire_format() {
        let connect = Connect::new(Protocol::MQTT, Level::Level3_1_1, Flags(0b1100_1110), 10);
        assert_eq!(connect.name(), "MQTT");
        assert_eq!(connect.encoded_len(), CONNECT_HEADER.len());

        let mut out = [0u8; 10];
        match connect.encode(&mut out) {
            Ok(written) => assert_eq!(written, CONNECT_HEADER.len()),
            Err(_) => panic!("encoding into a buffer of exactly encoded_len() bytes failed"),
        }
        assert_eq!(out, CONNECT_HEADER);
    }
}