// mqtt5_protocol/packet/connect.rs

1use crate::encoding::{decode_binary, decode_string, encode_binary, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::ConnectFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::prelude::{format, String, ToString, Vec};
6use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
7use crate::types::{ConnectOptions, WillMessage, WillProperties};
8use crate::QoS;
9use bytes::{Buf, BufMut, Bytes};
10
/// Protocol name carried at the start of every CONNECT variable header.
const PROTOCOL_NAME: &'static str = "MQTT";
/// Protocol level byte for MQTT v3.1.1.
const PROTOCOL_VERSION_V311: u8 = 4;
/// Protocol level byte for MQTT v5.0.
const PROTOCOL_VERSION_V5: u8 = 5;
14
15/// MQTT CONNECT packet
16#[derive(Debug, Clone)]
17pub struct ConnectPacket {
18    /// Protocol version (4 for v3.1.1, 5 for v5.0)
19    pub protocol_version: u8,
20    /// Clean start flag (Clean Session in v3.1.1)
21    pub clean_start: bool,
22    /// Keep alive interval in seconds
23    pub keep_alive: u16,
24    /// Client identifier
25    pub client_id: String,
26    /// Username (optional)
27    pub username: Option<String>,
28    /// Password (optional)
29    pub password: Option<Vec<u8>>,
30    /// Will message (optional)
31    pub will: Option<WillMessage>,
32    /// CONNECT properties (v5.0 only)
33    pub properties: Properties,
34    /// Will properties (v5.0 only)
35    pub will_properties: Properties,
36}
37
38impl ConnectPacket {
39    /// Creates a new CONNECT packet from options
40    #[must_use]
41    pub fn new(options: ConnectOptions) -> Self {
42        let properties = Self::build_connect_properties(&options.properties);
43        let will_properties = options
44            .will
45            .as_ref()
46            .map_or_else(Properties::default, |will| {
47                Self::build_will_properties(&will.properties)
48            });
49
50        Self {
51            protocol_version: PROTOCOL_VERSION_V5,
52            clean_start: options.clean_start,
53            keep_alive: Self::calculate_keep_alive(options.keep_alive),
54            client_id: options.client_id,
55            username: options.username,
56            password: options.password,
57            will: options.will,
58            properties,
59            will_properties,
60        }
61    }
62
63    /// Builds CONNECT properties from options
64    fn build_connect_properties(props: &crate::types::ConnectProperties) -> Properties {
65        let mut properties = Properties::default();
66
67        if let Some(val) = props.session_expiry_interval {
68            let _ = properties.add(
69                PropertyId::SessionExpiryInterval,
70                PropertyValue::FourByteInteger(val),
71            );
72        }
73        if let Some(val) = props.receive_maximum {
74            let _ = properties.add(
75                PropertyId::ReceiveMaximum,
76                PropertyValue::TwoByteInteger(val),
77            );
78        }
79        if let Some(val) = props.maximum_packet_size {
80            let _ = properties.add(
81                PropertyId::MaximumPacketSize,
82                PropertyValue::FourByteInteger(val),
83            );
84        }
85        if let Some(val) = props.topic_alias_maximum {
86            let _ = properties.add(
87                PropertyId::TopicAliasMaximum,
88                PropertyValue::TwoByteInteger(val),
89            );
90        }
91        if let Some(val) = props.request_response_information {
92            let _ = properties.add(
93                PropertyId::RequestResponseInformation,
94                PropertyValue::Byte(u8::from(val)),
95            );
96        }
97        if let Some(val) = props.request_problem_information {
98            let _ = properties.add(
99                PropertyId::RequestProblemInformation,
100                PropertyValue::Byte(u8::from(val)),
101            );
102        }
103        if let Some(val) = &props.authentication_method {
104            let _ = properties.add(
105                PropertyId::AuthenticationMethod,
106                PropertyValue::Utf8String(val.clone()),
107            );
108        }
109        if let Some(val) = &props.authentication_data {
110            let _ = properties.add(
111                PropertyId::AuthenticationData,
112                PropertyValue::BinaryData(val.clone().into()),
113            );
114        }
115        for (key, value) in &props.user_properties {
116            let _ = properties.add(
117                PropertyId::UserProperty,
118                PropertyValue::Utf8StringPair(key.clone(), value.clone()),
119            );
120        }
121
122        properties
123    }
124
125    /// Builds will properties from will options
126    fn build_will_properties(will_props: &crate::types::WillProperties) -> Properties {
127        let mut properties = Properties::default();
128
129        if let Some(val) = will_props.will_delay_interval {
130            let _ = properties.add(
131                PropertyId::WillDelayInterval,
132                PropertyValue::FourByteInteger(val),
133            );
134        }
135        if let Some(val) = will_props.payload_format_indicator {
136            let _ = properties.add(
137                PropertyId::PayloadFormatIndicator,
138                PropertyValue::Byte(u8::from(val)),
139            );
140        }
141        if let Some(val) = will_props.message_expiry_interval {
142            let _ = properties.add(
143                PropertyId::MessageExpiryInterval,
144                PropertyValue::FourByteInteger(val),
145            );
146        }
147        if let Some(val) = &will_props.content_type {
148            let _ = properties.add(
149                PropertyId::ContentType,
150                PropertyValue::Utf8String(val.clone()),
151            );
152        }
153        if let Some(val) = &will_props.response_topic {
154            let _ = properties.add(
155                PropertyId::ResponseTopic,
156                PropertyValue::Utf8String(val.clone()),
157            );
158        }
159        if let Some(val) = &will_props.correlation_data {
160            let _ = properties.add(
161                PropertyId::CorrelationData,
162                PropertyValue::BinaryData(val.clone().into()),
163            );
164        }
165        for (key, value) in &will_props.user_properties {
166            let _ = properties.add(
167                PropertyId::UserProperty,
168                PropertyValue::Utf8StringPair(key.clone(), value.clone()),
169            );
170        }
171
172        properties
173    }
174
175    /// Calculates keep alive value, clamping to u16 range
176    fn calculate_keep_alive(keep_alive: crate::time::Duration) -> u16 {
177        keep_alive
178            .as_secs()
179            .min(u64::from(u16::MAX))
180            .try_into()
181            .unwrap_or(u16::MAX)
182    }
183
184    /// Creates a v3.1.1 compatible CONNECT packet
185    #[must_use]
186    pub fn new_v311(options: ConnectOptions) -> Self {
187        Self {
188            protocol_version: PROTOCOL_VERSION_V311,
189            clean_start: options.clean_start,
190            keep_alive: Self::calculate_keep_alive(options.keep_alive),
191            client_id: options.client_id,
192            username: options.username,
193            password: options.password,
194            will: options.will,
195            properties: Properties::default(),
196            will_properties: Properties::default(),
197        }
198    }
199
200    /// Creates connect flags byte
201    fn connect_flags(&self) -> u8 {
202        let mut flags = 0u8;
203
204        if self.clean_start {
205            flags |= ConnectFlags::CleanStart as u8;
206        }
207
208        if let Some(ref will) = self.will {
209            flags |= ConnectFlags::WillFlag as u8;
210            flags = ConnectFlags::with_will_qos(flags, will.qos as u8);
211            if will.retain {
212                flags |= ConnectFlags::WillRetain as u8;
213            }
214        }
215
216        if self.username.is_some() {
217            flags |= ConnectFlags::UsernameFlag as u8;
218        }
219
220        if self.password.is_some() {
221            flags |= ConnectFlags::PasswordFlag as u8;
222        }
223
224        flags
225    }
226}
227
228impl MqttPacket for ConnectPacket {
229    fn packet_type(&self) -> PacketType {
230        PacketType::Connect
231    }
232
233    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
234        // Variable header
235        encode_string(buf, PROTOCOL_NAME)?;
236        buf.put_u8(self.protocol_version);
237        buf.put_u8(self.connect_flags());
238        buf.put_u16(self.keep_alive);
239
240        // Properties (v5.0 only)
241        if self.protocol_version == PROTOCOL_VERSION_V5 {
242            self.properties.encode(buf)?;
243        }
244
245        // Payload
246        encode_string(buf, &self.client_id)?;
247
248        // Will
249        if let Some(ref will) = self.will {
250            if self.protocol_version == PROTOCOL_VERSION_V5 {
251                self.will_properties.encode(buf)?;
252            }
253            encode_string(buf, &will.topic)?;
254            encode_binary(buf, &will.payload)?;
255        }
256
257        // Username
258        if let Some(ref username) = self.username {
259            encode_string(buf, username)?;
260        }
261
262        // Password
263        if let Some(ref password) = self.password {
264            encode_binary(buf, password)?;
265        }
266
267        Ok(())
268    }
269
270    fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
271        // Decode variable header
272        let protocol_version = Self::decode_protocol_header(buf)?;
273        let (flags, keep_alive) = Self::decode_connect_flags_and_keepalive(buf)?;
274
275        // Properties (v5.0 only)
276        let properties = if protocol_version == PROTOCOL_VERSION_V5 {
277            Properties::decode(buf)?
278        } else {
279            Properties::default()
280        };
281
282        // Decode payload
283        let client_id = decode_string(buf)?;
284        let (will, will_properties) = Self::decode_will(buf, &flags, protocol_version)?;
285        let (username, password) = Self::decode_credentials(buf, &flags)?;
286
287        Ok(Self {
288            protocol_version,
289            clean_start: flags.clean_start,
290            keep_alive,
291            client_id,
292            username,
293            password: password.map(|p| p.to_vec()),
294            will,
295            properties,
296            will_properties,
297        })
298    }
299}
300
301/// Helper struct to hold decoded connect flags
302struct DecodedConnectFlags {
303    clean_start: bool,
304    will_flag: bool,
305    will_qos: u8,
306    will_retain: bool,
307    credentials: CredentialFlags,
308}
309
/// Presence bits for the optional User Name and Password payload fields.
struct CredentialFlags {
    /// Password Flag: a password field is present in the payload.
    password_flag: bool,
    /// User Name Flag: a user name field is present in the payload.
    username_flag: bool,
}
314
315impl ConnectPacket {
316    /// Decode and validate protocol header
317    fn decode_protocol_header<B: Buf>(buf: &mut B) -> Result<u8> {
318        // Protocol name
319        let protocol_name = decode_string(buf)?;
320        if protocol_name != PROTOCOL_NAME {
321            return Err(MqttError::ProtocolError(format!(
322                "Invalid protocol name: {protocol_name}"
323            )));
324        }
325
326        // Protocol version
327        if !buf.has_remaining() {
328            return Err(MqttError::MalformedPacket(
329                "Missing protocol version".to_string(),
330            ));
331        }
332        let protocol_version = buf.get_u8();
333
334        Ok(protocol_version)
335    }
336
337    /// Decode connect flags and keep alive
338    fn decode_connect_flags_and_keepalive<B: Buf>(
339        buf: &mut B,
340    ) -> Result<(DecodedConnectFlags, u16)> {
341        // Connect flags
342        if !buf.has_remaining() {
343            return Err(MqttError::MalformedPacket(
344                "Missing connect flags".to_string(),
345            ));
346        }
347        let flags = buf.get_u8();
348
349        // Parse flags using BeBytes decomposition
350        let decomposed_flags = ConnectFlags::decompose(flags);
351
352        // Validate reserved bit
353        if decomposed_flags.contains(&ConnectFlags::Reserved) {
354            return Err(MqttError::MalformedPacket(
355                "Reserved flag bit must be 0".to_string(),
356            ));
357        }
358
359        let credentials = CredentialFlags {
360            username_flag: decomposed_flags.contains(&ConnectFlags::UsernameFlag),
361            password_flag: decomposed_flags.contains(&ConnectFlags::PasswordFlag),
362        };
363
364        let decoded_flags = DecodedConnectFlags {
365            clean_start: decomposed_flags.contains(&ConnectFlags::CleanStart),
366            will_flag: decomposed_flags.contains(&ConnectFlags::WillFlag),
367            will_qos: ConnectFlags::extract_will_qos(flags),
368            will_retain: decomposed_flags.contains(&ConnectFlags::WillRetain),
369            credentials,
370        };
371
372        // Keep alive
373        if buf.remaining() < 2 {
374            return Err(MqttError::MalformedPacket("Missing keep alive".to_string()));
375        }
376        let keep_alive = buf.get_u16();
377
378        Ok((decoded_flags, keep_alive))
379    }
380
381    /// Decode will message if present
382    fn decode_will<B: Buf>(
383        buf: &mut B,
384        flags: &DecodedConnectFlags,
385        protocol_version: u8,
386    ) -> Result<(Option<WillMessage>, Properties)> {
387        if !flags.will_flag {
388            return Ok((None, Properties::default()));
389        }
390
391        let will_properties = if protocol_version == PROTOCOL_VERSION_V5 {
392            Properties::decode(buf)?
393        } else {
394            Properties::default()
395        };
396
397        let topic = decode_string(buf)?;
398        let payload = decode_binary(buf)?;
399
400        let qos = match flags.will_qos {
401            0 => QoS::AtMostOnce,
402            1 => QoS::AtLeastOnce,
403            2 => QoS::ExactlyOnce,
404            _ => return Err(MqttError::MalformedPacket("Invalid will QoS".to_string())),
405        };
406
407        // Convert Properties to WillProperties
408        let will_props = Self::properties_to_will_properties(&will_properties);
409
410        let will = WillMessage {
411            topic,
412            payload: payload.to_vec(),
413            qos,
414            retain: flags.will_retain,
415            properties: will_props,
416        };
417
418        Ok((Some(will), will_properties))
419    }
420
421    /// Convert Properties to `WillProperties`
422    fn properties_to_will_properties(props: &Properties) -> WillProperties {
423        use crate::protocol::v5::properties::{PropertyId, PropertyValue};
424
425        let mut will_props = WillProperties::default();
426
427        // Extract will delay interval
428        if let Some(PropertyValue::FourByteInteger(delay)) =
429            props.get(PropertyId::WillDelayInterval)
430        {
431            will_props.will_delay_interval = Some(*delay);
432        }
433
434        // Extract payload format indicator
435        if let Some(PropertyValue::Byte(indicator)) = props.get(PropertyId::PayloadFormatIndicator)
436        {
437            will_props.payload_format_indicator = Some(*indicator != 0);
438        }
439
440        // Extract message expiry interval
441        if let Some(PropertyValue::FourByteInteger(expiry)) =
442            props.get(PropertyId::MessageExpiryInterval)
443        {
444            will_props.message_expiry_interval = Some(*expiry);
445        }
446
447        // Extract content type
448        if let Some(PropertyValue::Utf8String(content_type)) = props.get(PropertyId::ContentType) {
449            will_props.content_type = Some(content_type.clone());
450        }
451
452        // Extract response topic
453        if let Some(PropertyValue::Utf8String(topic)) = props.get(PropertyId::ResponseTopic) {
454            will_props.response_topic = Some(topic.clone());
455        }
456
457        // Extract correlation data
458        if let Some(PropertyValue::BinaryData(data)) = props.get(PropertyId::CorrelationData) {
459            will_props.correlation_data = Some(data.to_vec());
460        }
461
462        // Extract user properties
463        if let Some(values) = props.get_all(PropertyId::UserProperty) {
464            for value in values {
465                if let PropertyValue::Utf8StringPair(key, val) = value {
466                    will_props.user_properties.push((key.clone(), val.clone()));
467                }
468            }
469        }
470
471        will_props
472    }
473
474    /// Decode username and password if present
475    fn decode_credentials<B: Buf>(
476        buf: &mut B,
477        flags: &DecodedConnectFlags,
478    ) -> Result<(Option<String>, Option<Bytes>)> {
479        let username = if flags.credentials.username_flag {
480            Some(decode_string(buf)?)
481        } else {
482            None
483        };
484
485        let password = if flags.credentials.password_flag {
486            Some(decode_binary(buf)?)
487        } else {
488            None
489        };
490
491        // Validate password without username
492        if password.is_some() && username.is_none() {
493            return Err(MqttError::MalformedPacket(
494                "Password without username is not allowed".to_string(),
495            ));
496        }
497
498        Ok((username, password))
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::Duration;
    use bytes::BytesMut;

    /// Encodes `packet`, decodes its fixed header, then decodes the body back.
    fn round_trip(packet: &ConnectPacket) -> (FixedHeader, ConnectPacket) {
        let mut wire = BytesMut::new();
        packet.encode(&mut wire).unwrap();
        let header = FixedHeader::decode(&mut wire).unwrap();
        let decoded = ConnectPacket::decode_body(&mut wire, &header).unwrap();
        (header, decoded)
    }

    /// Decodes a hand-built CONNECT body using a placeholder fixed header.
    fn decode_raw(wire: &mut BytesMut) -> Result<ConnectPacket> {
        let header = FixedHeader::new(PacketType::Connect, 0, 0);
        ConnectPacket::decode_body(wire, &header)
    }

    #[test]
    fn test_connect_packet_basic() {
        let packet = ConnectPacket::new(ConnectOptions::new("test-client"));

        assert_eq!(packet.protocol_version, PROTOCOL_VERSION_V5);
        assert!(packet.clean_start, "clean start should default to on");
        assert_eq!(packet.keep_alive, 60);
        assert_eq!(packet.client_id, "test-client");
        assert!(packet.username.is_none());
        assert!(packet.password.is_none());
        assert!(packet.will.is_none());
    }

    #[test]
    fn test_connect_packet_with_credentials() {
        let packet =
            ConnectPacket::new(ConnectOptions::new("test-client").with_credentials("user", b"pass"));

        assert_eq!(packet.username, Some("user".to_string()));
        assert_eq!(packet.password, Some(b"pass".to_vec()));
    }

    #[test]
    fn test_connect_packet_with_will() {
        let last_will = WillMessage::new("will/topic", b"will payload")
            .with_qos(QoS::AtLeastOnce)
            .with_retain(true);
        let packet = ConnectPacket::new(ConnectOptions::new("test-client").with_will(last_will));

        assert!(packet.will.is_some());
        let will = packet.will.as_ref().unwrap();
        assert_eq!(will.topic, "will/topic");
        assert_eq!(will.payload, b"will payload");
        assert_eq!(will.qos, QoS::AtLeastOnce);
        assert!(will.retain);
    }

    #[test]
    fn test_connect_flags() {
        // Defaults: only Clean Start (0x02).
        let defaults = ConnectPacket::new(ConnectOptions::new("test"));
        assert_eq!(defaults.connect_flags(), 0x02);

        // User Name (0x80) + Password (0x40), Clean Start off.
        let with_login = ConnectPacket::new(
            ConnectOptions::new("test")
                .with_clean_start(false)
                .with_credentials("user", b"pass"),
        );
        assert_eq!(with_login.connect_flags(), 0xC0);

        // Clean Start (0x02) + Will (0x04) + Will QoS 2 (0x10) + Will Retain (0x20).
        let last_will = WillMessage::new("topic", b"payload")
            .with_qos(QoS::ExactlyOnce)
            .with_retain(true);
        let with_will = ConnectPacket::new(ConnectOptions::new("test").with_will(last_will));
        assert_eq!(with_will.connect_flags(), 0x36);
    }

    #[test]
    fn test_connect_encode_decode_v5() {
        let packet = ConnectPacket::new(
            ConnectOptions::new("test-client-123")
                .with_keep_alive(Duration::from_secs(120))
                .with_credentials("testuser", b"testpass"),
        );

        let (header, decoded) = round_trip(&packet);

        assert_eq!(header.packet_type, PacketType::Connect);
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V5);
        assert_eq!(decoded.client_id, "test-client-123");
        assert_eq!(decoded.keep_alive, 120);
        assert_eq!(decoded.username, Some("testuser".to_string()));
        assert_eq!(decoded.password, Some(b"testpass".to_vec()));
    }

    #[test]
    fn test_connect_encode_decode_v311() {
        let packet = ConnectPacket::new_v311(ConnectOptions::new("mqtt-311-client"));

        let (_, decoded) = round_trip(&packet);

        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V311);
        assert_eq!(decoded.client_id, "mqtt-311-client");
    }

    #[test]
    fn test_connect_invalid_protocol_name() {
        let mut wire = BytesMut::new();
        encode_string(&mut wire, "INVALID").unwrap();
        wire.put_u8(5);

        assert!(decode_raw(&mut wire).is_err());
    }

    #[test]
    fn test_connect_invalid_protocol_version() {
        // The body ends right after the version byte. Decoding must therefore
        // fail even if the version byte itself is not validated.
        let mut wire = BytesMut::new();
        encode_string(&mut wire, "MQTT").unwrap();
        wire.put_u8(99);

        assert!(decode_raw(&mut wire).is_err());
    }

    #[test]
    fn test_connect_password_without_username() {
        let mut wire = BytesMut::new();
        encode_string(&mut wire, "MQTT").unwrap();
        wire.put_u8(5); // protocol level: v5.0
        wire.put_u8(0x40); // Password flag set, User Name flag clear
        wire.put_u16(60); // keep alive
        wire.put_u8(0); // empty CONNECT properties
        encode_string(&mut wire, "client").unwrap();
        encode_binary(&mut wire, b"password").unwrap();

        assert!(decode_raw(&mut wire).is_err());
    }
}
639}