hebo_codec/v3/
connect.rs

1// Copyright (c) 2020 Xu Shaohua <shaohua@biofan.org>. All rights reserved.
2// Use of this source is governed by Apache-2.0 License that can be found
3// in the LICENSE file.
4
5use std::convert::TryFrom;
6
7use crate::base::{PROTOCOL_NAME, PROTOCOL_NAME_V3};
8use crate::connect_flags::ConnectFlags;
9use crate::utils::validate_client_id;
10use crate::{
11    validate_keep_alive, BinaryData, ByteArray, DecodeError, DecodePacket, EncodeError,
12    EncodePacket, FixedHeader, KeepAlive, Packet, PacketType, ProtocolLevel, PubTopic, QoS,
13    StringData, VarIntError,
14};
15
16/// `ConnectPacket` consists of three parts:
17/// * `FixedHeader`
18/// * `VariableHeader`
19/// * `Payload`
20/// Note that fixed header part is same in all packets so that we just ignore it.
21///
22/// Basic struct of `ConnectPacket` is as below:
23/// ```txt
24///  7                          0
25/// +----------------------------+
26/// | Fixed header               |
27/// |                            |
28/// +----------------------------+
29/// | Protocol name              |
30/// |                            |
31/// +----------------------------+
32/// | Protocol level             |
33/// +----------------------------+
34/// | Connect flags              |
35/// +----------------------------+
36/// | Keep alive                 |
37/// |                            |
38/// +----------------------------+
39/// | Client id length           |
40/// |                            |
41/// +----------------------------+
42/// | Client id string ...       |
43/// +----------------------------+
44/// | Will topic length          |
45/// |                            |
46/// +----------------------------+
47/// | Will topic string ...      |
48/// +----------------------------+
49/// | Will message length        |
50/// |                            |
51/// +----------------------------+
52/// | Will message bytes ...     |
53/// +----------------------------+
54/// | Username length            |
55/// |                            |
56/// +----------------------------+
57/// | Username string ...        |
58/// +----------------------------+
59/// | Password length            |
60/// |                            |
61/// +----------------------------+
62/// | Password bytes ...         |
63/// +----------------------------+
64/// ```
65#[allow(clippy::module_name_repetitions)]
66#[derive(Clone, Debug, Default, PartialEq, Eq)]
67pub struct ConnectPacket {
68    /// Protocol name can be `MQTT` in specification for MQTT v3.1.1.
69    ///
70    /// Or `MQIsdp` for MQTT v3.1.
71    protocol_name: StringData,
72
73    protocol_level: ProtocolLevel,
74
75    connect_flags: ConnectFlags,
76
77    /// Time interval between two packets in seconds.
78    /// Client must send PingRequest Packet before exceeding this interval.
79    /// If this value is not zero and time exceeds after last packet, the Server
80    /// will disconnect the network.
81    ///
82    /// If this value is zero, the Server is not required to disconnect the network.
83    keep_alive: KeepAlive,
84
85    /// Payload is `client_id`.
86    /// `client_id` is generated in client side. Normally it can be `device_id` or just
87    /// randomly generated string.
88    /// `client_id` is used to identify client connections in server. Session is based on this field.
89    /// It must be valid UTF-8 string, length shall be between 1 and 23 bytes.
90    /// It can only contain the characters: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
91    /// If `client_id` is invalid, the Server will reply ConnectAck Packet with return code
92    /// 0x02(Identifier rejected).
93    client_id: StringData,
94
95    /// If the `will` flag is true in `connect_flags`, then `will_topic` field must be set.
96    /// It will be used as the topic of Will Message.
97    will_topic: Option<PubTopic>,
98
99    /// If the `will` flag is true in `connect_flags`, then `will_message` field must be set.
100    /// It will be used as the payload of Will Message.
101    /// It consists of 0 to 64k bytes of binary data.
102    will_message: BinaryData,
103
104    /// If the `username` flag is true in `connect_flags`, then `username` field must be set.
105    /// It is a valid UTF-8 string.
106    username: StringData,
107
108    /// If the `password` flag is true in `connect_flags`, then `password` field must be set.
109    /// It consists of 0 to 64k bytes of binary data.
110    password: BinaryData,
111}
112
113impl ConnectPacket {
114    /// Create a new connect packet with `client_id`.
115    ///
116    /// # Errors
117    ///
118    /// Returns error if `client_id` is invalid.
119    pub fn new(client_id: &str) -> Result<Self, EncodeError> {
120        let protocol_name = StringData::from(PROTOCOL_NAME)?;
121        validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
122        let client_id = StringData::from(client_id)?;
123        Ok(Self {
124            protocol_name,
125            keep_alive: KeepAlive::new(60),
126            client_id,
127            ..Self::default()
128        })
129    }
130
131    /// Create a new connect packet with `client_id` with mqtt 3.1 protocol.
132    ///
133    /// # Errors
134    ///
135    /// Returns error if `client_id` is invalid.
136    pub fn new_v3(client_id: &str) -> Result<Self, EncodeError> {
137        let protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
138        let protocol_level = ProtocolLevel::V3;
139        validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
140        let client_id = StringData::from(client_id)?;
141        Ok(Self {
142            protocol_name,
143            protocol_level,
144            keep_alive: KeepAlive::new(60),
145            client_id,
146            ..Self::default()
147        })
148    }
149
150    /// Update protocol level.
151    ///
152    /// # Errors
153    /// Returns error if set protocol to MQTT v5.
154    pub fn set_protcol_level(&mut self, level: ProtocolLevel) -> Result<(), EncodeError> {
155        match level {
156            ProtocolLevel::V3 => {
157                self.protocol_name = StringData::from(PROTOCOL_NAME_V3)?;
158            }
159            ProtocolLevel::V4 => {
160                self.protocol_name = StringData::from(PROTOCOL_NAME)?;
161            }
162            ProtocolLevel::V5 => {
163                return Err(EncodeError::InvalidPacketLevel);
164            }
165        }
166        self.protocol_level = level;
167        Ok(())
168    }
169
170    /// Get current protocol level.
171    #[must_use]
172    #[inline]
173    pub const fn protocol_level(&self) -> ProtocolLevel {
174        self.protocol_level
175    }
176
177    /// Update connect flags
178    pub fn set_connect_flags(&mut self, flags: ConnectFlags) -> &Self {
179        self.connect_flags = flags;
180        self
181    }
182
183    /// Get current connect flags.
184    #[must_use]
185    #[inline]
186    pub const fn connect_flags(&self) -> &ConnectFlags {
187        &self.connect_flags
188    }
189
190    /// Update keep alive value in milliseconds.
191    pub fn set_keep_alive(&mut self, keep_alive: u16) -> &mut Self {
192        self.keep_alive = KeepAlive::new(keep_alive);
193        self
194    }
195
196    /// Get current keep alive value.
197    #[must_use]
198    #[inline]
199    pub const fn keep_alive(&self) -> u16 {
200        // TODO(Shaohua): Returns a duration
201        self.keep_alive.value()
202    }
203
204    /// Update client id.
205    ///
206    /// # Errors
207    ///
208    /// Returns error if `client_id` is invalid.
209    pub fn set_client_id(&mut self, client_id: &str) -> Result<&mut Self, EncodeError> {
210        validate_client_id(client_id).map_err(|_err| EncodeError::InvalidClientId)?;
211        self.client_id = StringData::from(client_id)?;
212        Ok(self)
213    }
214
215    /// Get current client id.
216    #[must_use]
217    pub fn client_id(&self) -> &str {
218        self.client_id.as_ref()
219    }
220
221    /// Update username value.
222    ///
223    /// # Errors
224    ///
225    /// Returns error if `username` contains invalid chars or too long.
226    pub fn set_username(&mut self, username: &str) -> Result<&mut Self, EncodeError> {
227        self.username = StringData::from(username)?;
228        Ok(self)
229    }
230
231    /// Get current username value.
232    #[must_use]
233    pub fn username(&self) -> &str {
234        self.username.as_ref()
235    }
236
237    /// Update password value.
238    ///
239    /// # Errors
240    ///
241    /// Returns error if `password` is too long.
242    pub fn set_password(&mut self, password: &[u8]) -> Result<&mut Self, EncodeError> {
243        self.password = BinaryData::from_slice(password)?;
244        Ok(self)
245    }
246
247    /// Get current password value.
248    #[must_use]
249    pub fn password(&self) -> &[u8] {
250        self.password.as_ref()
251    }
252
253    /// Update will-topic.
254    ///
255    /// # Errors
256    ///
257    /// Returns error if `topic` is invalid.
258    pub fn set_will_topic(&mut self, topic: &str) -> Result<&mut Self, EncodeError> {
259        if topic.is_empty() {
260            self.will_topic = None;
261        } else {
262            self.will_topic = Some(PubTopic::new(topic)?);
263        }
264        Ok(self)
265    }
266
267    /// Get current will-topic value.
268    #[must_use]
269    pub fn will_topic(&self) -> Option<&str> {
270        self.will_topic.as_ref().map(AsRef::as_ref)
271    }
272
273    /// Update will-message.
274    ///
275    /// # Errors
276    ///
277    /// Returns error if `message` is too long.
278    pub fn set_will_message(&mut self, message: &[u8]) -> Result<&mut Self, EncodeError> {
279        self.will_message = BinaryData::from_slice(message)?;
280        Ok(self)
281    }
282
283    /// Get current will-message value.
284    #[must_use]
285    pub fn will_message(&self) -> &[u8] {
286        self.will_message.as_ref()
287    }
288
289    // TODO(Shaohua): Add more getters/setters.
290
291    fn get_fixed_header(&self) -> Result<FixedHeader, VarIntError> {
292        let mut remaining_length = self.protocol_name.bytes()
293            + ProtocolLevel::bytes()
294            + ConnectFlags::bytes()
295            + KeepAlive::bytes()
296            + self.client_id.bytes();
297
298        // Check username/password/topic/message.
299        if self.connect_flags.will() {
300            assert!(self.will_topic.is_some());
301            if let Some(will_topic) = &self.will_topic {
302                remaining_length += will_topic.bytes();
303            }
304            remaining_length += self.will_message.bytes();
305        }
306        if self.connect_flags.has_username() {
307            remaining_length += self.username.bytes();
308        }
309        if self.connect_flags.has_password() {
310            remaining_length += self.password.bytes();
311        }
312        FixedHeader::new(PacketType::Connect, remaining_length)
313    }
314}
315
316impl EncodePacket for ConnectPacket {
317    fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
318        let old_len = v.len();
319
320        // Write fixed header
321        let fixed_header = self.get_fixed_header()?;
322        fixed_header.encode(v)?;
323
324        // Write variable header
325        self.protocol_name.encode(v)?;
326        self.protocol_level.encode(v)?;
327        self.connect_flags.encode(v)?;
328        self.keep_alive.encode(v)?;
329
330        // Write payload
331        self.client_id.encode(v)?;
332        if self.connect_flags.will() {
333            assert!(self.will_topic.is_some());
334            if let Some(will_topic) = &self.will_topic {
335                will_topic.encode(v)?;
336            }
337
338            self.will_message.encode(v)?;
339        }
340        if self.connect_flags.has_username() {
341            self.username.encode(v)?;
342        }
343        if self.connect_flags.has_password() {
344            self.password.encode(v)?;
345        }
346
347        Ok(v.len() - old_len)
348    }
349}
350
351impl DecodePacket for ConnectPacket {
352    fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
353        let fixed_header = FixedHeader::decode(ba)?;
354        if fixed_header.packet_type() != PacketType::Connect {
355            return Err(DecodeError::InvalidPacketType);
356        }
357
358        let protocol_name = StringData::decode(ba)?;
359        let protocol_level = ProtocolLevel::try_from(ba.read_byte()?)?;
360        match protocol_level {
361            ProtocolLevel::V3 => {
362                if protocol_name.as_ref() != PROTOCOL_NAME_V3 {
363                    return Err(DecodeError::InvalidProtocolName);
364                }
365            }
366            ProtocolLevel::V4 => {
367                if protocol_name.as_ref() != PROTOCOL_NAME {
368                    return Err(DecodeError::InvalidProtocolName);
369                }
370            }
371            ProtocolLevel::V5 => {
372                return Err(DecodeError::InvalidProtocolLevel);
373            }
374        }
375
376        let connect_flags = ConnectFlags::decode(ba)?;
377        // If the Will Flag is set to 0 the Will QoS and Will Retain fields in the
378        // Connect Flags MUST be set to zero and the Will Topic and Will Message fields
379        // MUST NOT be present in the payload [MQTT-3.1.2-11].
380        //
381        // If the Will Flag is set to 0, then the Will QoS MUST be set to 0 (0x00) [MQTT-3.1.2-13].
382        //
383        // If the Will Flag is set to 1, the value of Will QoS can be 0 (0x00), 1 (0x01), or 2 (0x02).
384        // It MUST NOT be 3 (0x03) [MQTT-3.1.2-14].
385        if !connect_flags.will()
386            && (connect_flags.will_qos() != QoS::AtMostOnce || connect_flags.will_retain())
387        {
388            return Err(DecodeError::InvalidConnectFlags);
389        }
390
391        // If the User Name Flag is set to 0, the Password Flag MUST be set to 0 [MQTT-3.1.2-22].
392        if !connect_flags.has_username() && connect_flags.has_password() {
393            return Err(DecodeError::InvalidConnectFlags);
394        }
395
396        let keep_alive = KeepAlive::decode(ba)?;
397        validate_keep_alive(keep_alive)?;
398
399        // A Server MAY allow a Client to supply a ClientId that has a length of zero bytes,
400        // however if it does so the Server MUST treat this as a special case and assign
401        // a unique ClientId to that Client. It MUST then process the CONNECT packet
402        // as if the Client had provided that unique ClientId [MQTT-3.1.3-6].
403        let client_id = StringData::decode(ba).map_err(|_err| DecodeError::InvalidClientId)?;
404
405        // If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession
406        // to 1 [MQTT-3.1.3-7].
407        //
408        // If the Client supplies a zero-byte ClientId with CleanSession set to 0, the Server
409        // MUST respond to the CONNECT Packet with a CONNACK return code 0x02 (Identifier rejected)
410        // and then close the Network Connection [MQTT-3.1.3-8].
411        if client_id.is_empty() && !connect_flags.clean_session() {
412            return Err(DecodeError::InvalidClientId);
413        }
414        validate_client_id(client_id.as_ref())?;
415
416        let will_topic = if connect_flags.will() {
417            Some(PubTopic::decode(ba)?)
418        } else {
419            None
420        };
421        let will_message = if connect_flags.will() {
422            BinaryData::decode(ba)?
423        } else {
424            BinaryData::new()
425        };
426
427        let username = if connect_flags.has_username() {
428            StringData::decode(ba)?
429        } else {
430            StringData::new()
431        };
432
433        let password = if connect_flags.has_password() {
434            BinaryData::decode(ba)?
435        } else {
436            BinaryData::new()
437        };
438
439        Ok(Self {
440            protocol_name,
441            protocol_level,
442            connect_flags,
443            keep_alive,
444            client_id,
445            will_topic,
446            will_message,
447            username,
448            password,
449        })
450    }
451}
452
453impl Packet for ConnectPacket {
454    fn packet_type(&self) -> PacketType {
455        PacketType::Connect
456    }
457
458    fn bytes(&self) -> Result<usize, VarIntError> {
459        let fixed_header = self.get_fixed_header()?;
460        Ok(fixed_header.bytes() + fixed_header.remaining_length())
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::{ByteArray, ConnectPacket, DecodeError, DecodePacket, EncodePacket, Packet};

    /// A minimal MQTT v3.1.1 CONNECT packet.
    ///
    /// Fields: clean session, keep alive 60 s, client id "wvPTXcCw", and no will,
    /// username or password.
    fn minimal_connect() -> Vec<u8> {
        vec![
            16, 20, 0, 4, 77, 81, 84, 84, 4, 2, 0, 60, 0, 8, 119, 118, 80, 84, 88, 99, 67, 119,
        ]
    }

    #[test]
    fn test_decode() {
        let buf = minimal_connect();
        let mut ba = ByteArray::new(&buf);
        let packet = ConnectPacket::decode(&mut ba).expect("valid CONNECT packet");
        assert_eq!(packet.client_id(), "wvPTXcCw");
        assert_eq!(packet.keep_alive(), 60);
        assert!(packet.connect_flags().clean_session());
        assert!(!packet.connect_flags().will());
        assert_eq!(packet.will_topic(), None);
        assert_eq!(packet.username(), "");
        assert!(packet.password().is_empty());
    }

    #[test]
    fn test_encode_round_trip() {
        let buf = minimal_connect();
        let mut ba = ByteArray::new(&buf);
        let packet = ConnectPacket::decode(&mut ba).expect("valid CONNECT packet");

        let mut out = Vec::new();
        let written = packet.encode(&mut out).expect("encodable packet");
        assert_eq!(written, buf.len());
        assert_eq!(out, buf);
        assert_eq!(packet.bytes().expect("valid remaining length"), buf.len());
    }

    #[test]
    fn test_decode_rejects_empty_client_id_without_clean_session() {
        // Remaining length 12, connect flags 0 (CleanSession cleared), zero-length client id.
        let buf: Vec<u8> = vec![16, 12, 0, 4, 77, 81, 84, 84, 4, 0, 0, 60, 0, 0];
        let mut ba = ByteArray::new(&buf);
        assert!(matches!(
            ConnectPacket::decode(&mut ba),
            Err(DecodeError::InvalidClientId)
        ));
    }

    #[test]
    fn test_decode_rejects_mismatched_protocol_name() {
        let mut buf = minimal_connect();
        // Protocol level 3 (MQTT 3.1) requires "MQIsdp", not "MQTT".
        buf[8] = 3;
        let mut ba = ByteArray::new(&buf);
        assert!(matches!(
            ConnectPacket::decode(&mut ba),
            Err(DecodeError::InvalidProtocolName)
        ));
    }
}