hebo_codec/v3/connect.rs
1// Copyright (c) 2020 Xu Shaohua <shaohua@biofan.org>. All rights reserved.
2// Use of this source is governed by Apache-2.0 License that can be found
3// in the LICENSE file.
4
5use std::convert::TryFrom;
6
7use crate::base::{PROTOCOL_NAME, PROTOCOL_NAME_V3};
8use crate::connect_flags::ConnectFlags;
9use crate::utils::validate_client_id;
10use crate::{
11 validate_keep_alive, BinaryData, ByteArray, DecodeError, DecodePacket, EncodeError,
12 EncodePacket, FixedHeader, KeepAlive, Packet, PacketType, ProtocolLevel, PubTopic, QoS,
13 StringData, VarIntError,
14};
15
/// `ConnectPacket` consists of three parts:
/// * `FixedHeader`
/// * `VariableHeader`
/// * `Payload`
///
/// The fixed header is not stored as a field: it is recomputed from the other
/// fields whenever the packet is encoded or its size is measured.
///
/// Basic struct of `ConnectPacket` is as below:
/// ```txt
/// 7                            0
/// +----------------------------+
/// | Fixed header               |
/// |                            |
/// +----------------------------+
/// | Protocol name              |
/// |                            |
/// +----------------------------+
/// | Protocol level             |
/// +----------------------------+
/// | Connect flags              |
/// +----------------------------+
/// | Keep alive                 |
/// |                            |
/// +----------------------------+
/// | Client id length           |
/// |                            |
/// +----------------------------+
/// | Client id string ...       |
/// +----------------------------+
/// | Will topic length          |
/// |                            |
/// +----------------------------+
/// | Will topic string ...      |
/// +----------------------------+
/// | Will message length        |
/// |                            |
/// +----------------------------+
/// | Will message bytes ...     |
/// +----------------------------+
/// | Username length            |
/// |                            |
/// +----------------------------+
/// | Username string ...        |
/// +----------------------------+
/// | Password length            |
/// |                            |
/// +----------------------------+
/// | Password bytes ...         |
/// +----------------------------+
/// ```
///
/// The will topic/message, username and password fields are only present on
/// the wire when the matching bit is set in the connect flags.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectPacket {
    /// Protocol name: `MQTT` for MQTT v3.1.1, or `MQIsdp` for MQTT v3.1.
    ///
    /// It must agree with `protocol_level`; decoding rejects a mismatch.
    protocol_name: StringData,

    /// Protocol level. Only MQTT v3.1 (`V3`) and v3.1.1 (`V4`) are handled by
    /// this codec; `V5` is rejected both when setting the level and when decoding.
    protocol_level: ProtocolLevel,

    /// Connect flags. They decide which optional payload fields (will topic and
    /// will message, username, password) are encoded and decoded.
    connect_flags: ConnectFlags,

    /// Maximum interval, in seconds, between two control packets sent by the client.
    /// If the client has nothing else to send, it must send a PingRequest packet
    /// before the interval expires.
    ///
    /// If this value is non-zero and the Server receives no control packet within
    /// one and a half times this interval, the Server disconnects the network
    /// connection [MQTT-3.1.2-24].
    ///
    /// If this value is zero, the Server is not required to disconnect the network.
    keep_alive: KeepAlive,

    /// First payload field; it is always present.
    ///
    /// `client_id` is generated on the client side. Normally it is a `device_id` or a
    /// randomly generated string. The server uses it to identify client connections,
    /// and sessions are keyed on it.
    ///
    /// The spec requires servers to accept ids of 1 to 23 UTF-8 bytes that contain
    /// only "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    /// [MQTT-3.1.3-5]. Non-empty ids are checked with `validate_client_id`.
    /// A zero-length id is only acceptable together with the clean-session flag
    /// [MQTT-3.1.3-7]. A rejected id is answered with ConnectAck return code
    /// 0x02 (Identifier rejected).
    client_id: StringData,

    /// Topic of the Will Message.
    ///
    /// It must be `Some` whenever the `will` flag is set in `connect_flags`.
    /// Encoding and size computation `assert!` on this and panic otherwise.
    will_topic: Option<PubTopic>,

    /// Payload of the Will Message: 0 to 65535 bytes of binary data.
    ///
    /// It is only encoded or decoded when the `will` flag is set in `connect_flags`.
    will_message: BinaryData,

    /// User name, a valid UTF-8 string.
    ///
    /// It is only encoded or decoded when the `username` flag is set in `connect_flags`.
    username: StringData,

    /// Password: 0 to 65535 bytes of binary data.
    ///
    /// It is only encoded or decoded when the `password` flag is set in `connect_flags`.
    /// Decoding rejects a password flag without a username flag [MQTT-3.1.2-22].
    password: BinaryData,
}
112
113impl ConnectPacket {
114 /// Create a new connect packet with `client_id`.
115 ///
116 /// # Errors
117 ///
118 /// Returns error if `client_id` is invalid.
119 pub fn new(client_id: &str) -> Result<Self, EncodeError> {
120 let protocol_name = StringData::from(PROTOCOL_NAME)?;
121 validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
122 let client_id = StringData::from(client_id)?;
123 Ok(Self {
124 protocol_name,
125 keep_alive: KeepAlive::new(60),
126 client_id,
127 ..Self::default()
128 })
129 }
130
131 /// Create a new connect packet with `client_id` with mqtt 3.1 protocol.
132 ///
133 /// # Errors
134 ///
135 /// Returns error if `client_id` is invalid.
136 pub fn new_v3(client_id: &str) -> Result<Self, EncodeError> {
137 let protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
138 let protocol_level = ProtocolLevel::V3;
139 validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
140 let client_id = StringData::from(client_id)?;
141 Ok(Self {
142 protocol_name,
143 protocol_level,
144 keep_alive: KeepAlive::new(60),
145 client_id,
146 ..Self::default()
147 })
148 }
149
150 /// Update protocol level.
151 ///
152 /// # Errors
153 /// Returns error if set protocol to MQTT v5.
154 pub fn set_protcol_level(&mut self, level: ProtocolLevel) -> Result<(), EncodeError> {
155 match level {
156 ProtocolLevel::V3 => {
157 self.protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
158 }
159 ProtocolLevel::V4 => {
160 self.protocol_name = StringData::from(PROTOCOL_NAME)?;
161 }
162 ProtocolLevel::V5 => {
163 return Err(EncodeError::InvalidPacketLevel);
164 }
165 }
166 self.protocol_level = level;
167 Ok(())
168 }
169
170 /// Get current protocol level.
171 #[must_use]
172 #[inline]
173 pub const fn protocol_level(&self) -> ProtocolLevel {
174 self.protocol_level
175 }
176
177 /// Update connect flags
178 pub fn set_connect_flags(&mut self, flags: ConnectFlags) -> &Self {
179 self.connect_flags = flags;
180 self
181 }
182
183 /// Get current connect flags.
184 #[must_use]
185 #[inline]
186 pub const fn connect_flags(&self) -> &ConnectFlags {
187 &self.connect_flags
188 }
189
190 /// Update keep alive value in milliseconds.
191 pub fn set_keep_alive(&mut self, keep_alive: u16) -> &mut Self {
192 self.keep_alive = KeepAlive::new(keep_alive);
193 self
194 }
195
196 /// Get current keep alive value.
197 #[must_use]
198 #[inline]
199 pub const fn keep_alive(&self) -> u16 {
200 // TODO(Shaohua): Returns a duration
201 self.keep_alive.value()
202 }
203
204 /// Update client id.
205 ///
206 /// # Errors
207 ///
208 /// Returns error if `client_id` is invalid.
209 pub fn set_client_id(&mut self, client_id: &str) -> Result<&mut Self, EncodeError> {
210 validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
211 self.client_id = StringData::from(client_id)?;
212 Ok(self)
213 }
214
215 /// Get current client id.
216 #[must_use]
217 pub fn client_id(&self) -> &str {
218 self.client_id.as_ref()
219 }
220
221 /// Update username value.
222 ///
223 /// # Errors
224 ///
225 /// Returns error if `username` contains invalid chars or too long.
226 pub fn set_username(&mut self, username: &str) -> Result<&mut Self, EncodeError> {
227 self.username = StringData::from(username)?;
228 Ok(self)
229 }
230
231 /// Get current username value.
232 #[must_use]
233 pub fn username(&self) -> &str {
234 self.username.as_ref()
235 }
236
237 /// Update password value.
238 ///
239 /// # Errors
240 ///
241 /// Returns error if `password` is too long.
242 pub fn set_password(&mut self, password: &[u8]) -> Result<&mut Self, EncodeError> {
243 self.password = BinaryData::from_slice(password)?;
244 Ok(self)
245 }
246
247 /// Get current password value.
248 #[must_use]
249 pub fn password(&self) -> &[u8] {
250 self.password.as_ref()
251 }
252
253 /// Update will-topic.
254 ///
255 /// # Errors
256 ///
257 /// Returns error if `topic` is invalid.
258 pub fn set_will_topic(&mut self, topic: &str) -> Result<&mut Self, EncodeError> {
259 if topic.is_empty() {
260 self.will_topic = None;
261 } else {
262 self.will_topic = Some(PubTopic::new(topic)?);
263 }
264 Ok(self)
265 }
266
267 /// Get current will-topic value.
268 #[must_use]
269 pub fn will_topic(&self) -> Option<&str> {
270 self.will_topic.as_ref().map(AsRef::as_ref)
271 }
272
273 /// Update will-message.
274 ///
275 /// # Errors
276 ///
277 /// Returns error if `message` is too long.
278 pub fn set_will_message(&mut self, message: &[u8]) -> Result<&mut Self, EncodeError> {
279 self.will_message = BinaryData::from_slice(message)?;
280 Ok(self)
281 }
282
283 /// Get current will-message value.
284 #[must_use]
285 pub fn will_message(&self) -> &[u8] {
286 self.will_message.as_ref()
287 }
288
289 // TODO(Shaohua): Add more getters/setters.
290
291 fn get_fixed_header(&self) -> Result<FixedHeader, VarIntError> {
292 let mut remaining_length = self.protocol_name.bytes()
293 + ProtocolLevel::bytes()
294 + ConnectFlags::bytes()
295 + KeepAlive::bytes()
296 + self.client_id.bytes();
297
298 // Check username/password/topic/message.
299 if self.connect_flags.will() {
300 assert!(self.will_topic.is_some());
301 if let Some(will_topic) = &self.will_topic {
302 remaining_length += will_topic.bytes();
303 }
304 remaining_length += self.will_message.bytes();
305 }
306 if self.connect_flags.has_username() {
307 remaining_length += self.username.bytes();
308 }
309 if self.connect_flags.has_password() {
310 remaining_length += self.password.bytes();
311 }
312 FixedHeader::new(PacketType::Connect, remaining_length)
313 }
314}
315
316impl EncodePacket for ConnectPacket {
317 fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
318 let old_len = v.len();
319
320 // Write fixed header
321 let fixed_header = self.get_fixed_header()?;
322 fixed_header.encode(v)?;
323
324 // Write variable header
325 self.protocol_name.encode(v)?;
326 self.protocol_level.encode(v)?;
327 self.connect_flags.encode(v)?;
328 self.keep_alive.encode(v)?;
329
330 // Write payload
331 self.client_id.encode(v)?;
332 if self.connect_flags.will() {
333 assert!(self.will_topic.is_some());
334 if let Some(will_topic) = &self.will_topic {
335 will_topic.encode(v)?;
336 }
337
338 self.will_message.encode(v)?;
339 }
340 if self.connect_flags.has_username() {
341 self.username.encode(v)?;
342 }
343 if self.connect_flags.has_password() {
344 self.password.encode(v)?;
345 }
346
347 Ok(v.len() - old_len)
348 }
349}
350
351impl DecodePacket for ConnectPacket {
352 fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
353 let fixed_header = FixedHeader::decode(ba)?;
354 if fixed_header.packet_type() != PacketType::Connect {
355 return Err(DecodeError::InvalidPacketType);
356 }
357
358 let protocol_name = StringData::decode(ba)?;
359 let protocol_level = ProtocolLevel::try_from(ba.read_byte()?)?;
360 match protocol_level {
361 ProtocolLevel::V3 => {
362 if protocol_name.as_ref() != PROTOCOL_NAME_V3 {
363 return Err(DecodeError::InvalidProtocolName);
364 }
365 }
366 ProtocolLevel::V4 => {
367 if protocol_name.as_ref() != PROTOCOL_NAME {
368 return Err(DecodeError::InvalidProtocolName);
369 }
370 }
371 ProtocolLevel::V5 => {
372 return Err(DecodeError::InvalidProtocolLevel);
373 }
374 }
375
376 let connect_flags = ConnectFlags::decode(ba)?;
377 // If the Will Flag is set to 0 the Will QoS and Will Retain fields in the
378 // Connect Flags MUST be set to zero and the Will Topic and Will Message fields
379 // MUST NOT be present in the payload [MQTT-3.1.2-11].
380 //
381 // If the Will Flag is set to 0, then the Will QoS MUST be set to 0 (0x00) [MQTT-3.1.2-13].
382 //
383 // If the Will Flag is set to 1, the value of Will QoS can be 0 (0x00), 1 (0x01), or 2 (0x02).
384 // It MUST NOT be 3 (0x03) [MQTT-3.1.2-14].
385 if !connect_flags.will()
386 && (connect_flags.will_qos() != QoS::AtMostOnce || connect_flags.will_retain())
387 {
388 return Err(DecodeError::InvalidConnectFlags);
389 }
390
391 // If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22].
392 if !connect_flags.has_username() && connect_flags.has_password() {
393 return Err(DecodeError::InvalidConnectFlags);
394 }
395
396 let keep_alive = KeepAlive::decode(ba)?;
397 validate_keep_alive(keep_alive)?;
398
399 // A Server MAY allow a Client to supply a ClientId that has a length of zero bytes,
400 // however if it does so the Server MUST treat this as a special case and assign
401 // a unique ClientId to that Client. It MUST then process the CONNECT packet
402 // as if the Client had provided that unique ClientId [MQTT-3.1.3-6].
403 let client_id = StringData::decode(ba).map_err(|_err| DecodeError::InvalidClientId)?;
404
405 // If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession
406 // to 1 [MQTT-3.1.3-7].
407 //
408 // If the Client supplies a zero-byte ClientId with CleanSession set to 0, the Server
409 // MUST respond to the CONNECT Packet with a CONNACK return code 0x02 (Identifier rejected)
410 // and then close the Network Connection [MQTT-3.1.3-8].
411 if client_id.is_empty() && !connect_flags.clean_session() {
412 return Err(DecodeError::InvalidClientId);
413 }
414 validate_client_id(client_id.as_ref())?;
415
416 let will_topic = if connect_flags.will() {
417 Some(PubTopic::decode(ba)?)
418 } else {
419 None
420 };
421 let will_message = if connect_flags.will() {
422 BinaryData::decode(ba)?
423 } else {
424 BinaryData::new()
425 };
426
427 let username = if connect_flags.has_username() {
428 StringData::decode(ba)?
429 } else {
430 StringData::new()
431 };
432
433 let password = if connect_flags.has_password() {
434 BinaryData::decode(ba)?
435 } else {
436 BinaryData::new()
437 };
438
439 Ok(Self {
440 protocol_name,
441 protocol_level,
442 connect_flags,
443 keep_alive,
444 client_id,
445 will_topic,
446 will_message,
447 username,
448 password,
449 })
450 }
451}
452
453impl Packet for ConnectPacket {
454 fn packet_type(&self) -> PacketType {
455 PacketType::Connect
456 }
457
458 fn bytes(&self) -> Result<usize, VarIntError> {
459 let fixed_header = self.get_fixed_header()?;
460 Ok(fixed_header.bytes() + fixed_header.remaining_length())
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::{ByteArray, ConnectPacket, DecodeError, DecodePacket, EncodePacket, Packet};

    #[test]
    fn test_decode() {
        let buf: Vec<u8> = vec![
            16, 20, 0, 4, 77, 81, 84, 84, 4, 2, 0, 60, 0, 8, 119, 118, 80, 84, 88, 99, 67, 119,
        ];
        let mut ba = ByteArray::new(&buf);
        let packet = ConnectPacket::decode(&mut ba);
        assert!(packet.is_ok());
        let packet = packet.unwrap();
        assert_eq!(packet.client_id(), "wvPTXcCw");
        assert_eq!(packet.keep_alive(), 60);
        // Connect flags byte 0x02 only sets the clean-session bit.
        assert!(packet.connect_flags().clean_session());
        assert!(packet.will_topic().is_none());
        assert_eq!(packet.username(), "");
        assert!(packet.password().is_empty());
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let packet = ConnectPacket::new("wvPTXcCw").unwrap();
        let mut buf = Vec::new();
        let written = packet.encode(&mut buf).unwrap();
        assert_eq!(written, buf.len());
        assert_eq!(packet.bytes().unwrap(), buf.len());

        let mut ba = ByteArray::new(&buf);
        let decoded = ConnectPacket::decode(&mut ba).unwrap();
        assert_eq!(decoded.client_id(), packet.client_id());
        assert_eq!(decoded.keep_alive(), packet.keep_alive());
        assert_eq!(decoded.username(), packet.username());
        assert_eq!(decoded.password(), packet.password());
        assert_eq!(decoded.will_topic(), packet.will_topic());
        assert_eq!(decoded.will_message(), packet.will_message());
    }

    #[test]
    fn test_decode_rejects_empty_client_id_without_clean_session() {
        // Connect flags 0x00 (clean session cleared) with a zero-length client id.
        let buf: Vec<u8> = vec![16, 12, 0, 4, 77, 81, 84, 84, 4, 0, 0, 60, 0, 0];
        let mut ba = ByteArray::new(&buf);
        assert!(matches!(
            ConnectPacket::decode(&mut ba),
            Err(DecodeError::InvalidClientId)
        ));
    }

    #[test]
    fn test_decode_rejects_mismatched_protocol_name() {
        // Protocol level 3 (MQTT v3.1) paired with the v3.1.1 name "MQTT".
        let buf: Vec<u8> = vec![16, 12, 0, 4, 77, 81, 84, 84, 3, 2, 0, 60, 0, 0];
        let mut ba = ByteArray::new(&buf);
        assert!(matches!(
            ConnectPacket::decode(&mut ba),
            Err(DecodeError::InvalidProtocolName)
        ));
    }

    #[test]
    fn test_decode_rejects_password_without_username() {
        // Connect flags 0x42: password + clean session, username flag cleared.
        let buf: Vec<u8> = vec![
            16, 20, 0, 4, 77, 81, 84, 84, 4, 0x42, 0, 60, 0, 8, 119, 118, 80, 84, 88, 99, 67, 119,
        ];
        let mut ba = ByteArray::new(&buf);
        assert!(ConnectPacket::decode(&mut ba).is_err());
    }
}