Skip to main content

rumqttc/mqttbytes/v5/
connect.rs

1use super::{
2    BufMut, BytesMut, Error, FixedHeader, PropertyType, QoS, len_len, length, property, qos,
3    read_mqtt_bytes, read_mqtt_string, read_u8, read_u16, read_u32, write_mqtt_bytes,
4    write_mqtt_string, write_remaining_length,
5};
6use bytes::{Buf, Bytes};
7
8type ConnectReadParts = (Connect, Option<LastWill>, ConnectAuth);
9
10/// Connection packet initiated by the client
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct Connect {
13    /// Mqtt keep alive time
14    pub keep_alive: u16,
15    /// Client Id
16    pub client_id: String,
17    /// Clean session. Asks the broker to clear previous state
18    pub clean_start: bool,
19    pub properties: Option<ConnectProperties>,
20}
21
22impl Connect {
23    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<ConnectReadParts, Error> {
24        let variable_header_index = fixed_header.header_len;
25        bytes.advance(variable_header_index);
26
27        // Variable header
28        let protocol_name = read_mqtt_string(&mut bytes)?;
29        let protocol_level = read_u8(&mut bytes)?;
30        if protocol_name != "MQTT" {
31            return Err(Error::InvalidProtocol);
32        }
33
34        if protocol_level != 5 {
35            return Err(Error::InvalidProtocolLevel(protocol_level));
36        }
37
38        let connect_flags = read_u8(&mut bytes)?;
39        validate_connect_flags(connect_flags)?;
40        let clean_start = (connect_flags & 0b10) != 0;
41        let keep_alive = read_u16(&mut bytes)?;
42
43        let properties = ConnectProperties::read(&mut bytes)?;
44
45        let client_id = read_mqtt_string(&mut bytes)?;
46        let will = LastWill::read(connect_flags, &mut bytes)?;
47        let auth = ConnectAuth::read(connect_flags, &mut bytes)?;
48
49        let connect = Self {
50            keep_alive,
51            client_id,
52            clean_start,
53            properties,
54        };
55
56        Ok((connect, will, auth))
57    }
58
59    fn len(&self, will: Option<&LastWill>, auth: &ConnectAuth) -> usize {
60        let mut len = 2 + "MQTT".len() // protocol name
61                        + 1            // protocol version
62                        + 1            // connect flags
63                        + 2; // keep alive
64
65        if let Some(p) = &self.properties {
66            let properties_len = p.len();
67            let properties_len_len = len_len(properties_len);
68            len += properties_len_len + properties_len;
69        } else {
70            // just 1 byte representing 0 len
71            len += 1;
72        }
73
74        len += 2 + self.client_id.len();
75
76        // last will len
77        if let Some(w) = will {
78            len += w.len();
79        }
80
81        // username and password len
82        len += auth.len();
83
84        len
85    }
86
87    pub fn write(
88        &self,
89        will: &Option<LastWill>,
90        auth: &ConnectAuth,
91        buffer: &mut BytesMut,
92    ) -> Result<usize, Error> {
93        let len = self.len(will.as_ref(), auth);
94
95        buffer.put_u8(0b0001_0000);
96        let count = write_remaining_length(buffer, len)?;
97        write_mqtt_string(buffer, "MQTT");
98
99        buffer.put_u8(0x05);
100        let flags_index = 1 + count + 2 + 4 + 1;
101
102        let mut connect_flags = 0;
103        if self.clean_start {
104            connect_flags |= 0x02;
105        }
106
107        buffer.put_u8(connect_flags);
108        buffer.put_u16(self.keep_alive);
109
110        match &self.properties {
111            Some(p) => p.write(buffer)?,
112            None => {
113                write_remaining_length(buffer, 0)?;
114            }
115        }
116
117        write_mqtt_string(buffer, &self.client_id);
118
119        if let Some(w) = will {
120            connect_flags |= w.write(buffer)?;
121        }
122
123        connect_flags |= auth.write(buffer);
124
125        // update connect flags
126        buffer[flags_index] = connect_flags;
127        Ok(1 + count + len)
128    }
129
130    pub fn size(&self, will: &Option<LastWill>, auth: &ConnectAuth) -> usize {
131        let len = self.len(will.as_ref(), auth);
132        let remaining_len_size = len_len(len);
133
134        1 + remaining_len_size + len
135    }
136}
137
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct ConnectProperties {
140    /// Expiry interval property after loosing connection
141    pub session_expiry_interval: Option<u32>,
142    /// Maximum simultaneous packets
143    pub receive_maximum: Option<u16>,
144    /// Maximum packet size
145    pub max_packet_size: Option<u32>,
146    /// Maximum mapping integer for a topic
147    pub topic_alias_max: Option<u16>,
148    pub request_response_info: Option<u8>,
149    pub request_problem_info: Option<u8>,
150    /// List of user properties
151    pub user_properties: Vec<(String, String)>,
152    /// Method of authentication
153    pub authentication_method: Option<String>,
154    /// Authentication data
155    pub authentication_data: Option<Bytes>,
156}
157
158impl ConnectProperties {
159    #[must_use]
160    pub const fn new() -> Self {
161        Self {
162            session_expiry_interval: None,
163            receive_maximum: None,
164            max_packet_size: None,
165            topic_alias_max: None,
166            request_response_info: None,
167            request_problem_info: None,
168            user_properties: Vec::new(),
169            authentication_method: None,
170            authentication_data: None,
171        }
172    }
173
174    pub fn read(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
175        let mut session_expiry_interval = None;
176        let mut receive_maximum = None;
177        let mut max_packet_size = None;
178        let mut topic_alias_max = None;
179        let mut request_response_info = None;
180        let mut request_problem_info = None;
181        let mut user_properties = Vec::new();
182        let mut authentication_method = None;
183        let mut authentication_data = None;
184
185        let (properties_len_len, properties_len) = length(bytes.iter())?;
186        bytes.advance(properties_len_len);
187        if properties_len == 0 {
188            return Ok(None);
189        }
190
191        let mut cursor = 0;
192        // read until cursor reaches property length. properties_len = 0 will skip this loop
193        while cursor < properties_len {
194            let prop = read_u8(bytes)?;
195            cursor += 1;
196            match property(prop)? {
197                PropertyType::SessionExpiryInterval => {
198                    session_expiry_interval = Some(read_u32(bytes)?);
199                    cursor += 4;
200                }
201                PropertyType::ReceiveMaximum => {
202                    let receive_max = read_u16(bytes)?;
203                    if receive_max == 0 {
204                        return Err(Error::ProtocolError);
205                    }
206                    receive_maximum = Some(receive_max);
207                    cursor += 2;
208                }
209                PropertyType::MaximumPacketSize => {
210                    max_packet_size = Some(read_u32(bytes)?);
211                    cursor += 4;
212                }
213                PropertyType::TopicAliasMaximum => {
214                    topic_alias_max = Some(read_u16(bytes)?);
215                    cursor += 2;
216                }
217                PropertyType::RequestResponseInformation => {
218                    request_response_info = Some(read_u8(bytes)?);
219                    cursor += 1;
220                }
221                PropertyType::RequestProblemInformation => {
222                    request_problem_info = Some(read_u8(bytes)?);
223                    cursor += 1;
224                }
225                PropertyType::UserProperty => {
226                    let key = read_mqtt_string(bytes)?;
227                    let value = read_mqtt_string(bytes)?;
228                    cursor += 2 + key.len() + 2 + value.len();
229                    user_properties.push((key, value));
230                }
231                PropertyType::AuthenticationMethod => {
232                    let method = read_mqtt_string(bytes)?;
233                    cursor += 2 + method.len();
234                    authentication_method = Some(method);
235                }
236                PropertyType::AuthenticationData => {
237                    let data = read_mqtt_bytes(bytes)?;
238                    cursor += 2 + data.len();
239                    authentication_data = Some(data);
240                }
241                _ => return Err(Error::InvalidPropertyType(prop)),
242            }
243        }
244
245        Ok(Some(Self {
246            session_expiry_interval,
247            receive_maximum,
248            max_packet_size,
249            topic_alias_max,
250            request_response_info,
251            request_problem_info,
252            user_properties,
253            authentication_method,
254            authentication_data,
255        }))
256    }
257
258    fn len(&self) -> usize {
259        let mut len = 0;
260
261        if self.session_expiry_interval.is_some() {
262            len += 1 + 4;
263        }
264
265        if self.receive_maximum.is_some() {
266            len += 1 + 2;
267        }
268
269        if self.max_packet_size.is_some() {
270            len += 1 + 4;
271        }
272
273        if self.topic_alias_max.is_some() {
274            len += 1 + 2;
275        }
276
277        if self.request_response_info.is_some() {
278            len += 1 + 1;
279        }
280
281        if self.request_problem_info.is_some() {
282            len += 1 + 1;
283        }
284
285        for (key, value) in &self.user_properties {
286            len += 1 + 2 + key.len() + 2 + value.len();
287        }
288
289        if let Some(authentication_method) = &self.authentication_method {
290            len += 1 + 2 + authentication_method.len();
291        }
292
293        if let Some(authentication_data) = &self.authentication_data {
294            len += 1 + 2 + authentication_data.len();
295        }
296
297        len
298    }
299
300    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
301        let len = self.len();
302        write_remaining_length(buffer, len)?;
303
304        if let Some(session_expiry_interval) = self.session_expiry_interval {
305            buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
306            buffer.put_u32(session_expiry_interval);
307        }
308
309        if let Some(receive_maximum) = self.receive_maximum {
310            buffer.put_u8(PropertyType::ReceiveMaximum as u8);
311            buffer.put_u16(receive_maximum);
312        }
313
314        if let Some(max_packet_size) = self.max_packet_size {
315            buffer.put_u8(PropertyType::MaximumPacketSize as u8);
316            buffer.put_u32(max_packet_size);
317        }
318
319        if let Some(topic_alias_max) = self.topic_alias_max {
320            buffer.put_u8(PropertyType::TopicAliasMaximum as u8);
321            buffer.put_u16(topic_alias_max);
322        }
323
324        if let Some(request_response_info) = self.request_response_info {
325            buffer.put_u8(PropertyType::RequestResponseInformation as u8);
326            buffer.put_u8(request_response_info);
327        }
328
329        if let Some(request_problem_info) = self.request_problem_info {
330            buffer.put_u8(PropertyType::RequestProblemInformation as u8);
331            buffer.put_u8(request_problem_info);
332        }
333
334        for (key, value) in &self.user_properties {
335            buffer.put_u8(PropertyType::UserProperty as u8);
336            write_mqtt_string(buffer, key);
337            write_mqtt_string(buffer, value);
338        }
339
340        if let Some(authentication_method) = &self.authentication_method {
341            buffer.put_u8(PropertyType::AuthenticationMethod as u8);
342            write_mqtt_string(buffer, authentication_method);
343        }
344
345        if let Some(authentication_data) = &self.authentication_data {
346            buffer.put_u8(PropertyType::AuthenticationData as u8);
347            write_mqtt_bytes(buffer, authentication_data);
348        }
349
350        Ok(())
351    }
352}
353
354impl Default for ConnectProperties {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360/// `LastWill` that broker forwards on behalf of the client
361#[derive(Debug, Clone, PartialEq, Eq)]
362pub struct LastWill {
363    pub topic: Bytes,
364    pub message: Bytes,
365    pub qos: QoS,
366    pub retain: bool,
367    pub properties: Option<LastWillProperties>,
368}
369
370impl LastWill {
371    pub fn new(
372        topic: impl Into<String>,
373        payload: impl Into<Vec<u8>>,
374        qos: QoS,
375        retain: bool,
376        properties: Option<LastWillProperties>,
377    ) -> Self {
378        let topic = Bytes::from(topic.into().into_bytes());
379        Self {
380            topic,
381            message: Bytes::from(payload.into()),
382            qos,
383            retain,
384            properties,
385        }
386    }
387
388    fn len(&self) -> usize {
389        let mut len = 0;
390
391        if let Some(p) = &self.properties {
392            let properties_len = p.len();
393            let properties_len_len = len_len(properties_len);
394            len += properties_len_len + properties_len;
395        } else {
396            // just 1 byte representing 0 len
397            len += 1;
398        }
399
400        len += 2 + self.topic.len() + 2 + self.message.len();
401        len
402    }
403
404    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Self>, Error> {
405        let o = match connect_flags & 0b100 {
406            0 if (connect_flags & 0b0011_1000) != 0 => {
407                return Err(Error::IncorrectPacketFormat);
408            }
409            0 => None,
410            _ => {
411                // Properties in variable header
412                let properties = LastWillProperties::read(bytes)?;
413
414                let will_topic = read_mqtt_bytes(bytes)?;
415                let will_message = read_mqtt_bytes(bytes)?;
416                let qos_num = (connect_flags & 0b11000) >> 3;
417                let will_qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
418                Some(Self {
419                    topic: will_topic,
420                    message: will_message,
421                    qos: will_qos,
422                    retain: (connect_flags & 0b0010_0000) != 0,
423                    properties,
424                })
425            }
426        };
427
428        Ok(o)
429    }
430
431    pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
432        let mut connect_flags = 0;
433
434        connect_flags |= 0x04 | ((self.qos as u8) << 3);
435        if self.retain {
436            connect_flags |= 0x20;
437        }
438
439        if let Some(p) = &self.properties {
440            p.write(buffer)?;
441        } else {
442            write_remaining_length(buffer, 0)?;
443        }
444
445        write_mqtt_bytes(buffer, &self.topic);
446        write_mqtt_bytes(buffer, &self.message);
447        Ok(connect_flags)
448    }
449}
450
451#[derive(Debug, Clone, PartialEq, Eq)]
452pub struct LastWillProperties {
453    pub delay_interval: Option<u32>,
454    pub payload_format_indicator: Option<u8>,
455    pub message_expiry_interval: Option<u32>,
456    pub content_type: Option<String>,
457    pub response_topic: Option<String>,
458    pub correlation_data: Option<Bytes>,
459    pub user_properties: Vec<(String, String)>,
460}
461
462impl LastWillProperties {
463    fn len(&self) -> usize {
464        let mut len = 0;
465
466        if self.delay_interval.is_some() {
467            len += 1 + 4;
468        }
469
470        if self.payload_format_indicator.is_some() {
471            len += 1 + 1;
472        }
473
474        if self.message_expiry_interval.is_some() {
475            len += 1 + 4;
476        }
477
478        if let Some(typ) = &self.content_type {
479            len += 1 + 2 + typ.len();
480        }
481
482        if let Some(topic) = &self.response_topic {
483            len += 1 + 2 + topic.len();
484        }
485
486        if let Some(data) = &self.correlation_data {
487            len += 1 + 2 + data.len();
488        }
489
490        for (key, value) in &self.user_properties {
491            len += 1 + 2 + key.len() + 2 + value.len();
492        }
493
494        len
495    }
496
497    pub fn read(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
498        let mut delay_interval = None;
499        let mut payload_format_indicator = None;
500        let mut message_expiry_interval = None;
501        let mut content_type = None;
502        let mut response_topic = None;
503        let mut correlation_data = None;
504        let mut user_properties = Vec::new();
505
506        let (properties_len_len, properties_len) = length(bytes.iter())?;
507        bytes.advance(properties_len_len);
508        if properties_len == 0 {
509            return Ok(None);
510        }
511
512        let mut cursor = 0;
513        // read until cursor reaches property length. properties_len = 0 will skip this loop
514        while cursor < properties_len {
515            let prop = read_u8(bytes)?;
516            cursor += 1;
517
518            match property(prop)? {
519                PropertyType::WillDelayInterval => {
520                    delay_interval = Some(read_u32(bytes)?);
521                    cursor += 4;
522                }
523                PropertyType::PayloadFormatIndicator => {
524                    payload_format_indicator = Some(read_u8(bytes)?);
525                    cursor += 1;
526                }
527                PropertyType::MessageExpiryInterval => {
528                    message_expiry_interval = Some(read_u32(bytes)?);
529                    cursor += 4;
530                }
531                PropertyType::ContentType => {
532                    let typ = read_mqtt_string(bytes)?;
533                    cursor += 2 + typ.len();
534                    content_type = Some(typ);
535                }
536                PropertyType::ResponseTopic => {
537                    let topic = read_mqtt_string(bytes)?;
538                    cursor += 2 + topic.len();
539                    response_topic = Some(topic);
540                }
541                PropertyType::CorrelationData => {
542                    let data = read_mqtt_bytes(bytes)?;
543                    cursor += 2 + data.len();
544                    correlation_data = Some(data);
545                }
546                PropertyType::UserProperty => {
547                    let key = read_mqtt_string(bytes)?;
548                    let value = read_mqtt_string(bytes)?;
549                    cursor += 2 + key.len() + 2 + value.len();
550                    user_properties.push((key, value));
551                }
552                _ => return Err(Error::InvalidPropertyType(prop)),
553            }
554        }
555
556        Ok(Some(Self {
557            delay_interval,
558            payload_format_indicator,
559            message_expiry_interval,
560            content_type,
561            response_topic,
562            correlation_data,
563            user_properties,
564        }))
565    }
566
567    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
568        let len = self.len();
569        write_remaining_length(buffer, len)?;
570
571        if let Some(delay_interval) = self.delay_interval {
572            buffer.put_u8(PropertyType::WillDelayInterval as u8);
573            buffer.put_u32(delay_interval);
574        }
575
576        if let Some(payload_format_indicator) = self.payload_format_indicator {
577            buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
578            buffer.put_u8(payload_format_indicator);
579        }
580
581        if let Some(message_expiry_interval) = self.message_expiry_interval {
582            buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
583            buffer.put_u32(message_expiry_interval);
584        }
585
586        if let Some(typ) = &self.content_type {
587            buffer.put_u8(PropertyType::ContentType as u8);
588            write_mqtt_string(buffer, typ);
589        }
590
591        if let Some(topic) = &self.response_topic {
592            buffer.put_u8(PropertyType::ResponseTopic as u8);
593            write_mqtt_string(buffer, topic);
594        }
595
596        if let Some(data) = &self.correlation_data {
597            buffer.put_u8(PropertyType::CorrelationData as u8);
598            write_mqtt_bytes(buffer, data);
599        }
600
601        for (key, value) in &self.user_properties {
602            buffer.put_u8(PropertyType::UserProperty as u8);
603            write_mqtt_string(buffer, key);
604            write_mqtt_string(buffer, value);
605        }
606
607        Ok(())
608    }
609}
610#[derive(Debug, Clone, PartialEq, Eq, Default)]
611pub enum ConnectAuth {
612    #[default]
613    None,
614    Username {
615        username: String,
616    },
617    Password {
618        password: Bytes,
619    },
620    UsernamePassword {
621        username: String,
622        password: Bytes,
623    },
624}
625
626impl ConnectAuth {
627    pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Self, Error> {
628        let username_flag = (connect_flags & 0b1000_0000) != 0;
629        let password_flag = (connect_flags & 0b0100_0000) != 0;
630
631        match (username_flag, password_flag) {
632            (false, false) => Ok(Self::None),
633            (true, false) => Ok(Self::Username {
634                username: read_mqtt_string(bytes)?,
635            }),
636            (false, true) => Ok(Self::Password {
637                password: read_mqtt_bytes(bytes)?,
638            }),
639            (true, true) => Ok(Self::UsernamePassword {
640                username: read_mqtt_string(bytes)?,
641                password: read_mqtt_bytes(bytes)?,
642            }),
643        }
644    }
645
646    const fn len(&self) -> usize {
647        match self {
648            Self::None => 0,
649            Self::Username { username } => 2 + username.len(),
650            Self::Password { password } => 2 + password.len(),
651            Self::UsernamePassword { username, password } => {
652                2 + username.len() + 2 + password.len()
653            }
654        }
655    }
656
657    pub fn write(&self, buffer: &mut BytesMut) -> u8 {
658        match self {
659            Self::None => 0,
660            Self::Username { username } => {
661                write_mqtt_string(buffer, username);
662                0x80
663            }
664            Self::Password { password } => {
665                write_mqtt_bytes(buffer, password.as_ref());
666                0x40
667            }
668            Self::UsernamePassword { username, password } => {
669                write_mqtt_string(buffer, username);
670                write_mqtt_bytes(buffer, password.as_ref());
671                0xC0
672            }
673        }
674    }
675}
676
677const fn validate_connect_flags(connect_flags: u8) -> Result<(), Error> {
678    if (connect_flags & 0x01) != 0 {
679        return Err(Error::IncorrectPacketFormat);
680    }
681
682    Ok(())
683}
684
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use crate::mqttbytes::v5::parse_fixed_header;
    use bytes::Bytes;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// Offset of the connect-flags byte when the remaining length fits in one
    /// byte: packet type (1) + remaining length (1) + "MQTT" (2 + 4) + level (1).
    const FLAGS_BYTE: usize = 9;

    /// Minimal CONNECT shared by the tests.
    fn sample_connect(properties: Option<ConnectProperties>) -> Connect {
        Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties,
        }
    }

    /// Parses the fixed header of `frame` and decodes the CONNECT it holds.
    fn decode(mut frame: BytesMut) -> Result<ConnectReadParts, Error> {
        let fixed_header = parse_fixed_header(frame.iter()).unwrap();
        let packet = frame.split_to(fixed_header.frame_length()).freeze();
        Connect::read(fixed_header, packet)
    }

    /// Encodes the sample CONNECT (no will) with `auth`; returns the raw frame.
    fn encode_with_auth(auth: &ConnectAuth) -> BytesMut {
        let mut frame = BytesMut::new();
        sample_connect(None).write(&None, auth, &mut frame).unwrap();
        frame
    }

    #[test]
    fn length_calculation() {
        // A user property long enough to push the remaining length past 127,
        // so the remaining-length field needs two bytes.
        let props = ConnectProperties {
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            ..ConnectProperties::new()
        };
        let mut frame = BytesMut::new();
        let reported = sample_connect(Some(props))
            .write(&None, &ConnectAuth::None, &mut frame)
            .unwrap();

        assert_eq!(reported, frame.len());
    }

    #[test]
    fn read_rejects_receive_maximum_zero() {
        // property length 3, ReceiveMaximum (0x21) with value 0
        let mut bytes = Bytes::from_static(&[0x03, 0x21, 0x00, 0x00]);

        assert!(matches!(
            ConnectProperties::read(&mut bytes),
            Err(Error::ProtocolError)
        ));
    }

    #[test]
    fn connect_roundtrips_binary_password() {
        let login = ConnectAuth::UsernamePassword {
            username: "binary".to_owned(),
            password: Bytes::from_static(b"\x00\xffproto\0buf"),
        };

        let (_, _, decoded) = decode(encode_with_auth(&login)).unwrap();

        assert_eq!(decoded, login);
    }

    #[test]
    fn connect_encoding_with_password_and_empty_username_writes_zero_len_username() {
        let login = ConnectAuth::UsernamePassword {
            username: String::new(),
            password: Bytes::from_static(b"pw"),
        };

        let frame = encode_with_auth(&login);
        assert_eq!(frame[FLAGS_BYTE], 0b1100_0010);

        let (_, _, decoded) = decode(frame).unwrap();
        assert_eq!(decoded, login);
    }

    #[test]
    fn connect_parsing_accepts_password_without_username_flag() {
        let frame = BytesMut::from(
            &[
                0x10, 0x15, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x05, 0x42, 0x00, 0x0a, 0x00,
                0x00, 0x04, b't', b'e', b's', b't', 0x00, 0x02, 0xff, 0x00,
            ][..],
        );

        let (_, will, auth) = decode(frame).unwrap();

        assert_eq!(will, None);
        assert_eq!(
            auth,
            ConnectAuth::Password {
                password: Bytes::from_static(b"\xff\0"),
            }
        );
    }

    #[test]
    fn connect_parsing_rejects_reserved_connect_flag_bit() {
        let frame = BytesMut::from(
            &[
                0x10, 0x10, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x05, 0x03, 0x00, 0x0a, 0x00,
                0x00, 0x00, 0x04, b't', b'e', b's', b't',
            ][..],
        );

        assert!(matches!(decode(frame), Err(Error::IncorrectPacketFormat)));
    }

    #[test]
    fn connect_roundtrips_explicitly_empty_password() {
        let auth = ConnectAuth::UsernamePassword {
            username: "user".to_owned(),
            password: Bytes::new(),
        };

        let frame = encode_with_auth(&auth);
        assert_eq!(frame[FLAGS_BYTE], 0b1100_0010);

        let (_, _, decoded) = decode(frame).unwrap();
        assert_eq!(decoded, auth);
    }

    #[test]
    fn connect_roundtrips_password_only_auth() {
        let auth = ConnectAuth::Password {
            password: Bytes::from_static(b"\x00\xffproto\0buf"),
        };

        let frame = encode_with_auth(&auth);
        assert_eq!(frame[FLAGS_BYTE], 0b0100_0010);

        let (_, _, decoded) = decode(frame).unwrap();
        assert_eq!(decoded, auth);
    }
}
880}