mqtt_bytes_v5/connect.rs

1use crate::{mqtt_string_eq, mqtt_string_new, MqttString};
2
3use super::{
4    len_len, length, property, qos, read_mqtt_bytes, read_mqtt_string, read_u16, read_u32, read_u8,
5    write_mqtt_bytes, write_mqtt_string, write_remaining_length, BufMut, BytesMut, Debug, Error,
6    FixedHeader, PropertyType, QoS,
7};
8use bytes::{Buf, Bytes};
9
10/// Connection packet initiated by the client
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct Connect {
13    /// Mqtt keep alive time
14    pub keep_alive: u16,
15    /// Client Id
16    pub client_id: MqttString,
17    /// Clean session. Asks the broker to clear previous state
18    pub clean_start: bool,
19    pub properties: Option<ConnectProperties>,
20}
21
22impl Connect {
23    #[allow(clippy::type_complexity)]
24    pub fn read(
25        fixed_header: FixedHeader,
26        mut bytes: Bytes,
27    ) -> Result<(Connect, Option<LastWill>, Option<Login>), Error> {
28        let variable_header_index = fixed_header.fixed_header_len;
29        bytes.advance(variable_header_index);
30
31        // Variable header
32        let protocol_name = read_mqtt_string(&mut bytes)?;
33        let protocol_level = read_u8(&mut bytes)?;
34        if !mqtt_string_eq(&protocol_name, "MQTT") {
35            return Err(Error::InvalidProtocol);
36        }
37
38        if protocol_level != 5 {
39            return Err(Error::InvalidProtocolLevel(protocol_level));
40        }
41
42        let connect_flags = read_u8(&mut bytes)?;
43        let clean_start = (connect_flags & 0b10) != 0;
44        let keep_alive = read_u16(&mut bytes)?;
45
46        let properties = ConnectProperties::read(&mut bytes)?;
47
48        let client_id = read_mqtt_string(&mut bytes)?;
49        let will = LastWill::read(connect_flags, &mut bytes)?;
50        let login = Login::read(connect_flags, &mut bytes)?;
51
52        let connect = Connect {
53            keep_alive,
54            client_id,
55            clean_start,
56            properties,
57        };
58
59        Ok((connect, will, login))
60    }
61
62    fn len(&self, will: &Option<LastWill>, l: &Option<Login>) -> usize {
63        let mut len = 2 + "MQTT".len() // protocol name
64                        + 1            // protocol version
65                        + 1            // connect flags
66                        + 2; // keep alive
67
68        if let Some(p) = &self.properties {
69            let properties_len = p.len();
70            let properties_len_len = len_len(properties_len);
71            len += properties_len_len + properties_len;
72        } else {
73            // just 1 byte representing 0 len
74            len += 1;
75        }
76
77        len += 2 + self.client_id.len();
78
79        // last will len
80        if let Some(w) = will {
81            len += w.len();
82        }
83
84        // username and password len
85        if let Some(l) = l {
86            len += l.len();
87        }
88
89        len
90    }
91
92    pub fn write(
93        &self,
94        will: &Option<LastWill>,
95        l: &Option<Login>,
96        buffer: &mut BytesMut,
97    ) -> Result<usize, Error> {
98        let len = self.len(will, l);
99
100        buffer.put_u8(0b0001_0000);
101        let count = write_remaining_length(buffer, len)?;
102        write_mqtt_string(buffer, &mqtt_string_new("MQTT"))?;
103
104        buffer.put_u8(0x05);
105        let flags_index = 1 + count + 2 + 4 + 1;
106
107        let mut connect_flags = 0;
108        if self.clean_start {
109            connect_flags |= 0x02;
110        }
111
112        buffer.put_u8(connect_flags);
113        buffer.put_u16(self.keep_alive);
114
115        match &self.properties {
116            Some(p) => p.write(buffer)?,
117            None => {
118                write_remaining_length(buffer, 0)?;
119            }
120        };
121
122        write_mqtt_string(buffer, &self.client_id)?;
123
124        if let Some(w) = will {
125            connect_flags |= w.write(buffer)?;
126        }
127
128        if let Some(l) = l {
129            connect_flags |= l.write(buffer)?;
130        }
131
132        // update connect flags
133        buffer[flags_index] = connect_flags;
134        Ok(1 + count + len)
135    }
136}
137
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct ConnectProperties {
140    /// Expiry interval property after loosing connection
141    pub session_expiry_interval: Option<u32>,
142    /// Maximum simultaneous packets
143    pub receive_maximum: Option<u16>,
144    /// Maximum packet size
145    pub max_packet_size: Option<u32>,
146    /// Maximum mapping integer for a topic
147    pub topic_alias_max: Option<u16>,
148    pub request_response_info: Option<u8>,
149    pub request_problem_info: Option<u8>,
150    /// List of user properties
151    pub user_properties: Vec<(MqttString, MqttString)>,
152    /// Method of authentication
153    pub authentication_method: Option<MqttString>,
154    /// Authentication data
155    pub authentication_data: Option<Bytes>,
156}
157
158impl ConnectProperties {
159    #[must_use]
160    pub fn new() -> ConnectProperties {
161        ConnectProperties {
162            session_expiry_interval: None,
163            receive_maximum: None,
164            max_packet_size: None,
165            topic_alias_max: None,
166            request_response_info: None,
167            request_problem_info: None,
168            user_properties: Vec::new(),
169            authentication_method: None,
170            authentication_data: None,
171        }
172    }
173
174    pub fn read(bytes: &mut Bytes) -> Result<Option<ConnectProperties>, Error> {
175        let mut session_expiry_interval = None;
176        let mut receive_maximum = None;
177        let mut max_packet_size = None;
178        let mut topic_alias_max = None;
179        let mut request_response_info = None;
180        let mut request_problem_info = None;
181        let mut user_properties = Vec::new();
182        let mut authentication_method = None;
183        let mut authentication_data = None;
184
185        let (properties_len_len, properties_len) = length(bytes.iter())?;
186        bytes.advance(properties_len_len);
187        if properties_len == 0 {
188            return Ok(None);
189        }
190
191        let mut cursor = 0;
192        // read until cursor reaches property length. properties_len = 0 will skip this loop
193        while cursor < properties_len {
194            let prop = read_u8(bytes)?;
195            cursor += 1;
196            match property(prop)? {
197                PropertyType::SessionExpiryInterval => {
198                    session_expiry_interval = Some(read_u32(bytes)?);
199                    cursor += 4;
200                }
201                PropertyType::ReceiveMaximum => {
202                    receive_maximum = Some(read_u16(bytes)?);
203                    cursor += 2;
204                }
205                PropertyType::MaximumPacketSize => {
206                    max_packet_size = Some(read_u32(bytes)?);
207                    cursor += 4;
208                }
209                PropertyType::TopicAliasMaximum => {
210                    topic_alias_max = Some(read_u16(bytes)?);
211                    cursor += 2;
212                }
213                PropertyType::RequestResponseInformation => {
214                    request_response_info = Some(read_u8(bytes)?);
215                    cursor += 1;
216                }
217                PropertyType::RequestProblemInformation => {
218                    request_problem_info = Some(read_u8(bytes)?);
219                    cursor += 1;
220                }
221                PropertyType::UserProperty => {
222                    let key = read_mqtt_string(bytes)?;
223                    let value = read_mqtt_string(bytes)?;
224                    cursor += 2 + key.len() + 2 + value.len();
225                    user_properties.push((key, value));
226                }
227                PropertyType::AuthenticationMethod => {
228                    let method = read_mqtt_string(bytes)?;
229                    cursor += 2 + method.len();
230                    authentication_method = Some(method);
231                }
232                PropertyType::AuthenticationData => {
233                    let data = read_mqtt_bytes(bytes)?;
234                    cursor += 2 + data.len();
235                    authentication_data = Some(data);
236                }
237                _ => return Err(Error::InvalidPropertyType(prop)),
238            }
239        }
240
241        Ok(Some(ConnectProperties {
242            session_expiry_interval,
243            receive_maximum,
244            max_packet_size,
245            topic_alias_max,
246            request_response_info,
247            request_problem_info,
248            user_properties,
249            authentication_method,
250            authentication_data,
251        }))
252    }
253
254    fn len(&self) -> usize {
255        let mut len = 0;
256
257        if self.session_expiry_interval.is_some() {
258            len += 1 + 4;
259        }
260
261        if self.receive_maximum.is_some() {
262            len += 1 + 2;
263        }
264
265        if self.max_packet_size.is_some() {
266            len += 1 + 4;
267        }
268
269        if self.topic_alias_max.is_some() {
270            len += 1 + 2;
271        }
272
273        if self.request_response_info.is_some() {
274            len += 1 + 1;
275        }
276
277        if self.request_problem_info.is_some() {
278            len += 1 + 1;
279        }
280
281        for (key, value) in &self.user_properties {
282            len += 1 + 2 + key.len() + 2 + value.len();
283        }
284
285        if let Some(authentication_method) = &self.authentication_method {
286            len += 1 + 2 + authentication_method.len();
287        }
288
289        if let Some(authentication_data) = &self.authentication_data {
290            len += 1 + 2 + authentication_data.len();
291        }
292
293        len
294    }
295
296    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
297        let len = self.len();
298        write_remaining_length(buffer, len)?;
299
300        if let Some(session_expiry_interval) = self.session_expiry_interval {
301            buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
302            buffer.put_u32(session_expiry_interval);
303        }
304
305        if let Some(receive_maximum) = self.receive_maximum {
306            buffer.put_u8(PropertyType::ReceiveMaximum as u8);
307            buffer.put_u16(receive_maximum);
308        }
309
310        if let Some(max_packet_size) = self.max_packet_size {
311            buffer.put_u8(PropertyType::MaximumPacketSize as u8);
312            buffer.put_u32(max_packet_size);
313        }
314
315        if let Some(topic_alias_max) = self.topic_alias_max {
316            buffer.put_u8(PropertyType::TopicAliasMaximum as u8);
317            buffer.put_u16(topic_alias_max);
318        }
319
320        if let Some(request_response_info) = self.request_response_info {
321            buffer.put_u8(PropertyType::RequestResponseInformation as u8);
322            buffer.put_u8(request_response_info);
323        }
324
325        if let Some(request_problem_info) = self.request_problem_info {
326            buffer.put_u8(PropertyType::RequestProblemInformation as u8);
327            buffer.put_u8(request_problem_info);
328        }
329
330        for (key, value) in &self.user_properties {
331            buffer.put_u8(PropertyType::UserProperty as u8);
332            write_mqtt_string(buffer, key)?;
333            write_mqtt_string(buffer, value)?;
334        }
335
336        if let Some(authentication_method) = &self.authentication_method {
337            buffer.put_u8(PropertyType::AuthenticationMethod as u8);
338            write_mqtt_string(buffer, authentication_method)?;
339        }
340
341        if let Some(authentication_data) = &self.authentication_data {
342            buffer.put_u8(PropertyType::AuthenticationData as u8);
343            write_mqtt_bytes(buffer, authentication_data)?;
344        }
345
346        Ok(())
347    }
348}
349
350impl Default for ConnectProperties {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
356/// `LastWill` that broker forwards on behalf of the client
357#[derive(Debug, Clone, PartialEq, Eq)]
358pub struct LastWill {
359    pub topic: Bytes,
360    pub message: Bytes,
361    pub qos: QoS,
362    pub retain: bool,
363    pub properties: Option<LastWillProperties>,
364}
365
366impl LastWill {
367    fn len(&self) -> usize {
368        let mut len = 0;
369
370        if let Some(p) = &self.properties {
371            let properties_len = p.len();
372            let properties_len_len = len_len(properties_len);
373            len += properties_len_len + properties_len;
374        } else {
375            // just 1 byte representing 0 len
376            len += 1;
377        }
378
379        len += 2 + self.topic.len() + 2 + self.message.len();
380        len
381    }
382
383    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
384        let o = match connect_flags & 0b100 {
385            0 if (connect_flags & 0b0011_1000) != 0 => {
386                return Err(Error::IncorrectPacketFormat);
387            }
388            0 => None,
389            _ => {
390                // Properties in variable header
391                let properties = LastWillProperties::read(bytes)?;
392
393                let will_topic = read_mqtt_bytes(bytes)?;
394                let will_message = read_mqtt_bytes(bytes)?;
395                let qos_num = (connect_flags & 0b11000) >> 3;
396                let will_qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
397                Some(LastWill {
398                    topic: will_topic,
399                    message: will_message,
400                    qos: will_qos,
401                    retain: (connect_flags & 0b0010_0000) != 0,
402                    properties,
403                })
404            }
405        };
406
407        Ok(o)
408    }
409
410    pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
411        let mut connect_flags = 0;
412
413        connect_flags |= 0x04 | (self.qos as u8) << 3;
414        if self.retain {
415            connect_flags |= 0x20;
416        }
417
418        if let Some(p) = &self.properties {
419            p.write(buffer)?;
420        } else {
421            write_remaining_length(buffer, 0)?;
422        }
423
424        write_mqtt_bytes(buffer, &self.topic)?;
425        write_mqtt_bytes(buffer, &self.message)?;
426        Ok(connect_flags)
427    }
428}
429
430#[derive(Debug, Clone, PartialEq, Eq)]
431pub struct LastWillProperties {
432    pub delay_interval: Option<u32>,
433    pub payload_format_indicator: Option<u8>,
434    pub message_expiry_interval: Option<u32>,
435    pub content_type: Option<MqttString>,
436    pub response_topic: Option<MqttString>,
437    pub correlation_data: Option<Bytes>,
438    pub user_properties: Vec<(MqttString, MqttString)>,
439}
440
441impl LastWillProperties {
442    fn len(&self) -> usize {
443        let mut len = 0;
444
445        if self.delay_interval.is_some() {
446            len += 1 + 4;
447        }
448
449        if self.payload_format_indicator.is_some() {
450            len += 1 + 1;
451        }
452
453        if self.message_expiry_interval.is_some() {
454            len += 1 + 4;
455        }
456
457        if let Some(typ) = &self.content_type {
458            len += 1 + 2 + typ.len();
459        }
460
461        if let Some(topic) = &self.response_topic {
462            len += 1 + 2 + topic.len();
463        }
464
465        if let Some(data) = &self.correlation_data {
466            len += 1 + 2 + data.len();
467        }
468
469        for (key, value) in &self.user_properties {
470            len += 1 + 2 + key.len() + 2 + value.len();
471        }
472
473        len
474    }
475
476    pub fn read(bytes: &mut Bytes) -> Result<Option<LastWillProperties>, Error> {
477        let mut delay_interval = None;
478        let mut payload_format_indicator = None;
479        let mut message_expiry_interval = None;
480        let mut content_type = None;
481        let mut response_topic = None;
482        let mut correlation_data = None;
483        let mut user_properties = Vec::new();
484
485        let (properties_len_len, properties_len) = length(bytes.iter())?;
486        bytes.advance(properties_len_len);
487        if properties_len == 0 {
488            return Ok(None);
489        }
490
491        let mut cursor = 0;
492        // read until cursor reaches property length. properties_len = 0 will skip this loop
493        while cursor < properties_len {
494            let prop = read_u8(bytes)?;
495            cursor += 1;
496
497            match property(prop)? {
498                PropertyType::WillDelayInterval => {
499                    delay_interval = Some(read_u32(bytes)?);
500                    cursor += 4;
501                }
502                PropertyType::PayloadFormatIndicator => {
503                    payload_format_indicator = Some(read_u8(bytes)?);
504                    cursor += 1;
505                }
506                PropertyType::MessageExpiryInterval => {
507                    message_expiry_interval = Some(read_u32(bytes)?);
508                    cursor += 4;
509                }
510                PropertyType::ContentType => {
511                    let typ = read_mqtt_string(bytes)?;
512                    cursor += 2 + typ.len();
513                    content_type = Some(typ);
514                }
515                PropertyType::ResponseTopic => {
516                    let topic = read_mqtt_string(bytes)?;
517                    cursor += 2 + topic.len();
518                    response_topic = Some(topic);
519                }
520                PropertyType::CorrelationData => {
521                    let data = read_mqtt_bytes(bytes)?;
522                    cursor += 2 + data.len();
523                    correlation_data = Some(data);
524                }
525                PropertyType::UserProperty => {
526                    let key = read_mqtt_string(bytes)?;
527                    let value = read_mqtt_string(bytes)?;
528                    cursor += 2 + key.len() + 2 + value.len();
529                    user_properties.push((key, value));
530                }
531                _ => return Err(Error::InvalidPropertyType(prop)),
532            }
533        }
534
535        Ok(Some(LastWillProperties {
536            delay_interval,
537            payload_format_indicator,
538            message_expiry_interval,
539            content_type,
540            response_topic,
541            correlation_data,
542            user_properties,
543        }))
544    }
545
546    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
547        let len = self.len();
548        write_remaining_length(buffer, len)?;
549
550        if let Some(delay_interval) = self.delay_interval {
551            buffer.put_u8(PropertyType::WillDelayInterval as u8);
552            buffer.put_u32(delay_interval);
553        }
554
555        if let Some(payload_format_indicator) = self.payload_format_indicator {
556            buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
557            buffer.put_u8(payload_format_indicator);
558        }
559
560        if let Some(message_expiry_interval) = self.message_expiry_interval {
561            buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
562            buffer.put_u32(message_expiry_interval);
563        }
564
565        if let Some(typ) = &self.content_type {
566            buffer.put_u8(PropertyType::ContentType as u8);
567            write_mqtt_string(buffer, typ)?;
568        }
569
570        if let Some(topic) = &self.response_topic {
571            buffer.put_u8(PropertyType::ResponseTopic as u8);
572            write_mqtt_string(buffer, topic)?;
573        }
574
575        if let Some(data) = &self.correlation_data {
576            buffer.put_u8(PropertyType::CorrelationData as u8);
577            write_mqtt_bytes(buffer, data)?;
578        }
579
580        for (key, value) in &self.user_properties {
581            buffer.put_u8(PropertyType::UserProperty as u8);
582            write_mqtt_string(buffer, key)?;
583            write_mqtt_string(buffer, value)?;
584        }
585
586        Ok(())
587    }
588}
589#[derive(Debug, Clone, PartialEq, Eq)]
590pub struct Login {
591    pub username: MqttString,
592    pub password: MqttString,
593}
594
595impl Login {
596    pub fn new<U: Into<MqttString>, P: Into<MqttString>>(u: U, p: P) -> Login {
597        Login {
598            username: u.into(),
599            password: p.into(),
600        }
601    }
602
603    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
604        let username = match connect_flags & 0b1000_0000 {
605            0 => MqttString::default(),
606            _ => read_mqtt_string(bytes)?,
607        };
608
609        let password = match connect_flags & 0b0100_0000 {
610            0 => MqttString::default(),
611            _ => read_mqtt_string(bytes)?,
612        };
613
614        if username.is_empty() && password.is_empty() {
615            Ok(None)
616        } else {
617            Ok(Some(Login { username, password }))
618        }
619    }
620
621    fn len(&self) -> usize {
622        let mut len = 0;
623
624        if !self.username.is_empty() {
625            len += 2 + self.username.len();
626        }
627
628        if !self.password.is_empty() {
629            len += 2 + self.password.len();
630        }
631
632        len
633    }
634
635    pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
636        let mut connect_flags = 0;
637        if !self.username.is_empty() {
638            connect_flags |= 0x80;
639            write_mqtt_string(buffer, &self.username)?;
640        }
641
642        if !self.password.is_empty() {
643            connect_flags |= 0x40;
644            write_mqtt_string(buffer, &self.password)?;
645        }
646
647        Ok(connect_flags)
648    }
649}
650
#[cfg(test)]
mod test {
    use crate::test::read_write_packets;
    use crate::Packet;

    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// The single user property pair shared by the fixtures below.
    fn user_props() -> Vec<(MqttString, MqttString)> {
        vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())]
    }

    #[test]
    fn length_calculation() {
        // The user property is meant to pad the packet past ~128 bytes, so
        // that the remaining_length field takes 2 bytes.
        let connect = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: Some(ConnectProperties {
                user_properties: user_props(),
                ..ConnectProperties::new()
            }),
        };

        let mut out = BytesMut::new();
        let reported = connect.write(&None, &None, &mut out).unwrap();

        assert_eq!(reported, out.len());
    }

    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    fn write_read_provider() -> Vec<Packet> {
        let bare = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: None,
        };

        let with_props = Connect {
            properties: Some(ConnectProperties {
                session_expiry_interval: Some(5),
                receive_maximum: Some(5),
                max_packet_size: Some(5),
                topic_alias_max: Some(5),
                request_response_info: Some(5),
                request_problem_info: Some(5),
                user_properties: user_props(),
                authentication_method: Some("method".into()),
                authentication_data: Some(Bytes::from("data")),
            }),
            ..bare.clone()
        };

        let will = LastWill {
            topic: Bytes::from("topic"),
            message: Bytes::from("message"),
            qos: QoS::AtLeastOnce,
            retain: true,
            properties: Some(LastWillProperties {
                delay_interval: Some(5),
                payload_format_indicator: Some(5),
                message_expiry_interval: Some(5),
                content_type: Some("type".into()),
                response_topic: Some("topic".into()),
                correlation_data: Some(Bytes::from("data")),
                user_properties: user_props(),
            }),
        };

        let login = Login::new("username", "password");

        vec![
            Packet::Connect(bare, None, None),
            Packet::Connect(with_props, Some(will), Some(login)),
        ]
    }
}