mqtt5_protocol/packet/connect.rs

1use crate::encoding::{decode_binary, decode_string, encode_binary, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::ConnectFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::{ConnectOptions, WillMessage, WillProperties};
7use crate::QoS;
8use bytes::{Buf, BufMut, Bytes};
9
/// Protocol name written into (and required from) every CONNECT variable header.
const PROTOCOL_NAME: &str = "MQTT";
/// Protocol level byte identifying MQTT 5.0.
const PROTOCOL_VERSION_V5: u8 = 0x05;
/// Protocol level byte identifying MQTT 3.1.1.
const PROTOCOL_VERSION_V311: u8 = 0x04;
13
14/// MQTT CONNECT packet
15#[derive(Debug, Clone)]
16pub struct ConnectPacket {
17    /// Protocol version (4 for v3.1.1, 5 for v5.0)
18    pub protocol_version: u8,
19    /// Clean start flag (Clean Session in v3.1.1)
20    pub clean_start: bool,
21    /// Keep alive interval in seconds
22    pub keep_alive: u16,
23    /// Client identifier
24    pub client_id: String,
25    /// Username (optional)
26    pub username: Option<String>,
27    /// Password (optional)
28    pub password: Option<Vec<u8>>,
29    /// Will message (optional)
30    pub will: Option<WillMessage>,
31    /// CONNECT properties (v5.0 only)
32    pub properties: Properties,
33    /// Will properties (v5.0 only)
34    pub will_properties: Properties,
35}
36
37impl ConnectPacket {
38    /// Creates a new CONNECT packet from options
39    #[must_use]
40    pub fn new(options: ConnectOptions) -> Self {
41        let properties = Self::build_connect_properties(&options.properties);
42        let will_properties = options
43            .will
44            .as_ref()
45            .map_or_else(Properties::default, |will| {
46                Self::build_will_properties(&will.properties)
47            });
48
49        Self {
50            protocol_version: PROTOCOL_VERSION_V5,
51            clean_start: options.clean_start,
52            keep_alive: Self::calculate_keep_alive(options.keep_alive),
53            client_id: options.client_id,
54            username: options.username,
55            password: options.password,
56            will: options.will,
57            properties,
58            will_properties,
59        }
60    }
61
62    /// Builds CONNECT properties from options
63    fn build_connect_properties(props: &crate::types::ConnectProperties) -> Properties {
64        let mut properties = Properties::default();
65
66        if let Some(val) = props.session_expiry_interval {
67            let _ = properties.add(
68                PropertyId::SessionExpiryInterval,
69                PropertyValue::FourByteInteger(val),
70            );
71        }
72        if let Some(val) = props.receive_maximum {
73            let _ = properties.add(
74                PropertyId::ReceiveMaximum,
75                PropertyValue::TwoByteInteger(val),
76            );
77        }
78        if let Some(val) = props.maximum_packet_size {
79            let _ = properties.add(
80                PropertyId::MaximumPacketSize,
81                PropertyValue::FourByteInteger(val),
82            );
83        }
84        if let Some(val) = props.topic_alias_maximum {
85            let _ = properties.add(
86                PropertyId::TopicAliasMaximum,
87                PropertyValue::TwoByteInteger(val),
88            );
89        }
90        if let Some(val) = props.request_response_information {
91            let _ = properties.add(
92                PropertyId::RequestResponseInformation,
93                PropertyValue::Byte(u8::from(val)),
94            );
95        }
96        if let Some(val) = props.request_problem_information {
97            let _ = properties.add(
98                PropertyId::RequestProblemInformation,
99                PropertyValue::Byte(u8::from(val)),
100            );
101        }
102        if let Some(val) = &props.authentication_method {
103            let _ = properties.add(
104                PropertyId::AuthenticationMethod,
105                PropertyValue::Utf8String(val.clone()),
106            );
107        }
108        if let Some(val) = &props.authentication_data {
109            let _ = properties.add(
110                PropertyId::AuthenticationData,
111                PropertyValue::BinaryData(val.clone().into()),
112            );
113        }
114        for (key, value) in &props.user_properties {
115            let _ = properties.add(
116                PropertyId::UserProperty,
117                PropertyValue::Utf8StringPair(key.clone(), value.clone()),
118            );
119        }
120
121        properties
122    }
123
124    /// Builds will properties from will options
125    fn build_will_properties(will_props: &crate::types::WillProperties) -> Properties {
126        let mut properties = Properties::default();
127
128        if let Some(val) = will_props.will_delay_interval {
129            let _ = properties.add(
130                PropertyId::WillDelayInterval,
131                PropertyValue::FourByteInteger(val),
132            );
133        }
134        if let Some(val) = will_props.payload_format_indicator {
135            let _ = properties.add(
136                PropertyId::PayloadFormatIndicator,
137                PropertyValue::Byte(u8::from(val)),
138            );
139        }
140        if let Some(val) = will_props.message_expiry_interval {
141            let _ = properties.add(
142                PropertyId::MessageExpiryInterval,
143                PropertyValue::FourByteInteger(val),
144            );
145        }
146        if let Some(val) = &will_props.content_type {
147            let _ = properties.add(
148                PropertyId::ContentType,
149                PropertyValue::Utf8String(val.clone()),
150            );
151        }
152        if let Some(val) = &will_props.response_topic {
153            let _ = properties.add(
154                PropertyId::ResponseTopic,
155                PropertyValue::Utf8String(val.clone()),
156            );
157        }
158        if let Some(val) = &will_props.correlation_data {
159            let _ = properties.add(
160                PropertyId::CorrelationData,
161                PropertyValue::BinaryData(val.clone().into()),
162            );
163        }
164        for (key, value) in &will_props.user_properties {
165            let _ = properties.add(
166                PropertyId::UserProperty,
167                PropertyValue::Utf8StringPair(key.clone(), value.clone()),
168            );
169        }
170
171        properties
172    }
173
174    /// Calculates keep alive value, clamping to u16 range
175    fn calculate_keep_alive(keep_alive: crate::time::Duration) -> u16 {
176        keep_alive
177            .as_secs()
178            .min(u64::from(u16::MAX))
179            .try_into()
180            .unwrap_or(u16::MAX)
181    }
182
183    /// Creates a v3.1.1 compatible CONNECT packet
184    #[must_use]
185    pub fn new_v311(options: ConnectOptions) -> Self {
186        Self {
187            protocol_version: PROTOCOL_VERSION_V311,
188            clean_start: options.clean_start,
189            keep_alive: Self::calculate_keep_alive(options.keep_alive),
190            client_id: options.client_id,
191            username: options.username,
192            password: options.password,
193            will: options.will,
194            properties: Properties::default(),
195            will_properties: Properties::default(),
196        }
197    }
198
199    /// Creates connect flags byte
200    fn connect_flags(&self) -> u8 {
201        let mut flags = 0u8;
202
203        if self.clean_start {
204            flags |= ConnectFlags::CleanStart as u8;
205        }
206
207        if let Some(ref will) = self.will {
208            flags |= ConnectFlags::WillFlag as u8;
209            flags = ConnectFlags::with_will_qos(flags, will.qos as u8);
210            if will.retain {
211                flags |= ConnectFlags::WillRetain as u8;
212            }
213        }
214
215        if self.username.is_some() {
216            flags |= ConnectFlags::UsernameFlag as u8;
217        }
218
219        if self.password.is_some() {
220            flags |= ConnectFlags::PasswordFlag as u8;
221        }
222
223        flags
224    }
225}
226
227impl MqttPacket for ConnectPacket {
228    fn packet_type(&self) -> PacketType {
229        PacketType::Connect
230    }
231
232    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
233        // Variable header
234        encode_string(buf, PROTOCOL_NAME)?;
235        buf.put_u8(self.protocol_version);
236        buf.put_u8(self.connect_flags());
237        buf.put_u16(self.keep_alive);
238
239        // Properties (v5.0 only)
240        if self.protocol_version == PROTOCOL_VERSION_V5 {
241            self.properties.encode(buf)?;
242        }
243
244        // Payload
245        encode_string(buf, &self.client_id)?;
246
247        // Will
248        if let Some(ref will) = self.will {
249            if self.protocol_version == PROTOCOL_VERSION_V5 {
250                self.will_properties.encode(buf)?;
251            }
252            encode_string(buf, &will.topic)?;
253            encode_binary(buf, &will.payload)?;
254        }
255
256        // Username
257        if let Some(ref username) = self.username {
258            encode_string(buf, username)?;
259        }
260
261        // Password
262        if let Some(ref password) = self.password {
263            encode_binary(buf, password)?;
264        }
265
266        Ok(())
267    }
268
269    fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
270        // Decode variable header
271        let protocol_version = Self::decode_protocol_header(buf)?;
272        let (flags, keep_alive) = Self::decode_connect_flags_and_keepalive(buf)?;
273
274        // Properties (v5.0 only)
275        let properties = if protocol_version == PROTOCOL_VERSION_V5 {
276            Properties::decode(buf)?
277        } else {
278            Properties::default()
279        };
280
281        // Decode payload
282        let client_id = decode_string(buf)?;
283        let (will, will_properties) = Self::decode_will(buf, &flags, protocol_version)?;
284        let (username, password) = Self::decode_credentials(buf, &flags)?;
285
286        Ok(Self {
287            protocol_version,
288            clean_start: flags.clean_start,
289            keep_alive,
290            client_id,
291            username,
292            password: password.map(|p| p.to_vec()),
293            will,
294            properties,
295            will_properties,
296        })
297    }
298}
299
/// Connect-flag bits decoded from the CONNECT variable header.
struct DecodedConnectFlags {
    /// Clean Start / Clean Session bit.
    clean_start: bool,
    /// Will flag: when set, a Will section follows the client id.
    will_flag: bool,
    /// Will QoS value as extracted by `ConnectFlags::extract_will_qos`;
    /// mapped to a `QoS` (and range-checked) while decoding the Will.
    will_qos: u8,
    /// Will Retain bit.
    will_retain: bool,
    /// Presence bits for the trailing credential fields.
    credentials: CredentialFlags,
}

/// Presence bits for the user name and password at the end of the payload.
struct CredentialFlags {
    /// User Name flag.
    username_flag: bool,
    /// Password flag.
    password_flag: bool,
}
313
314impl ConnectPacket {
315    /// Decode and validate protocol header
316    fn decode_protocol_header<B: Buf>(buf: &mut B) -> Result<u8> {
317        // Protocol name
318        let protocol_name = decode_string(buf)?;
319        if protocol_name != PROTOCOL_NAME {
320            return Err(MqttError::ProtocolError(format!(
321                "Invalid protocol name: {protocol_name}"
322            )));
323        }
324
325        // Protocol version
326        if !buf.has_remaining() {
327            return Err(MqttError::MalformedPacket(
328                "Missing protocol version".to_string(),
329            ));
330        }
331        let protocol_version = buf.get_u8();
332
333        Ok(protocol_version)
334    }
335
336    /// Decode connect flags and keep alive
337    fn decode_connect_flags_and_keepalive<B: Buf>(
338        buf: &mut B,
339    ) -> Result<(DecodedConnectFlags, u16)> {
340        // Connect flags
341        if !buf.has_remaining() {
342            return Err(MqttError::MalformedPacket(
343                "Missing connect flags".to_string(),
344            ));
345        }
346        let flags = buf.get_u8();
347
348        // Parse flags using BeBytes decomposition
349        let decomposed_flags = ConnectFlags::decompose(flags);
350
351        // Validate reserved bit
352        if decomposed_flags.contains(&ConnectFlags::Reserved) {
353            return Err(MqttError::MalformedPacket(
354                "Reserved flag bit must be 0".to_string(),
355            ));
356        }
357
358        let credentials = CredentialFlags {
359            username_flag: decomposed_flags.contains(&ConnectFlags::UsernameFlag),
360            password_flag: decomposed_flags.contains(&ConnectFlags::PasswordFlag),
361        };
362
363        let decoded_flags = DecodedConnectFlags {
364            clean_start: decomposed_flags.contains(&ConnectFlags::CleanStart),
365            will_flag: decomposed_flags.contains(&ConnectFlags::WillFlag),
366            will_qos: ConnectFlags::extract_will_qos(flags),
367            will_retain: decomposed_flags.contains(&ConnectFlags::WillRetain),
368            credentials,
369        };
370
371        // Keep alive
372        if buf.remaining() < 2 {
373            return Err(MqttError::MalformedPacket("Missing keep alive".to_string()));
374        }
375        let keep_alive = buf.get_u16();
376
377        Ok((decoded_flags, keep_alive))
378    }
379
380    /// Decode will message if present
381    fn decode_will<B: Buf>(
382        buf: &mut B,
383        flags: &DecodedConnectFlags,
384        protocol_version: u8,
385    ) -> Result<(Option<WillMessage>, Properties)> {
386        if !flags.will_flag {
387            return Ok((None, Properties::default()));
388        }
389
390        let will_properties = if protocol_version == PROTOCOL_VERSION_V5 {
391            Properties::decode(buf)?
392        } else {
393            Properties::default()
394        };
395
396        let topic = decode_string(buf)?;
397        let payload = decode_binary(buf)?;
398
399        let qos = match flags.will_qos {
400            0 => QoS::AtMostOnce,
401            1 => QoS::AtLeastOnce,
402            2 => QoS::ExactlyOnce,
403            _ => return Err(MqttError::MalformedPacket("Invalid will QoS".to_string())),
404        };
405
406        // Convert Properties to WillProperties
407        let will_props = Self::properties_to_will_properties(&will_properties);
408
409        let will = WillMessage {
410            topic,
411            payload: payload.to_vec(),
412            qos,
413            retain: flags.will_retain,
414            properties: will_props,
415        };
416
417        Ok((Some(will), will_properties))
418    }
419
420    /// Convert Properties to WillProperties
421    fn properties_to_will_properties(props: &Properties) -> WillProperties {
422        use crate::protocol::v5::properties::{PropertyId, PropertyValue};
423
424        let mut will_props = WillProperties::default();
425
426        // Extract will delay interval
427        if let Some(PropertyValue::FourByteInteger(delay)) =
428            props.get(PropertyId::WillDelayInterval)
429        {
430            will_props.will_delay_interval = Some(*delay);
431        }
432
433        // Extract payload format indicator
434        if let Some(PropertyValue::Byte(indicator)) = props.get(PropertyId::PayloadFormatIndicator)
435        {
436            will_props.payload_format_indicator = Some(*indicator != 0);
437        }
438
439        // Extract message expiry interval
440        if let Some(PropertyValue::FourByteInteger(expiry)) =
441            props.get(PropertyId::MessageExpiryInterval)
442        {
443            will_props.message_expiry_interval = Some(*expiry);
444        }
445
446        // Extract content type
447        if let Some(PropertyValue::Utf8String(content_type)) = props.get(PropertyId::ContentType) {
448            will_props.content_type = Some(content_type.clone());
449        }
450
451        // Extract response topic
452        if let Some(PropertyValue::Utf8String(topic)) = props.get(PropertyId::ResponseTopic) {
453            will_props.response_topic = Some(topic.clone());
454        }
455
456        // Extract correlation data
457        if let Some(PropertyValue::BinaryData(data)) = props.get(PropertyId::CorrelationData) {
458            will_props.correlation_data = Some(data.to_vec());
459        }
460
461        // Extract user properties
462        if let Some(values) = props.get_all(PropertyId::UserProperty) {
463            for value in values {
464                if let PropertyValue::Utf8StringPair(key, val) = value {
465                    will_props.user_properties.push((key.clone(), val.clone()));
466                }
467            }
468        }
469
470        will_props
471    }
472
473    /// Decode username and password if present
474    fn decode_credentials<B: Buf>(
475        buf: &mut B,
476        flags: &DecodedConnectFlags,
477    ) -> Result<(Option<String>, Option<Bytes>)> {
478        let username = if flags.credentials.username_flag {
479            Some(decode_string(buf)?)
480        } else {
481            None
482        };
483
484        let password = if flags.credentials.password_flag {
485            Some(decode_binary(buf)?)
486        } else {
487            None
488        };
489
490        // Validate password without username
491        if password.is_some() && username.is_none() {
492            return Err(MqttError::MalformedPacket(
493                "Password without username is not allowed".to_string(),
494            ));
495        }
496
497        Ok((username, password))
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::Duration;
    use bytes::BytesMut;

    /// Encodes `packet`, parses the fixed header, then decodes the body back.
    fn round_trip(packet: &ConnectPacket) -> (FixedHeader, ConnectPacket) {
        let mut wire = BytesMut::new();
        packet.encode(&mut wire).unwrap();

        let header = FixedHeader::decode(&mut wire).unwrap();
        let decoded = ConnectPacket::decode_body(&mut wire, &header).unwrap();
        (header, decoded)
    }

    /// Builds a body holding only the protocol name and protocol level.
    fn header_only(name: &str, version: u8) -> BytesMut {
        let mut wire = BytesMut::new();
        encode_string(&mut wire, name).unwrap();
        wire.put_u8(version);
        wire
    }

    /// Decodes a hand-built body under a synthetic CONNECT fixed header.
    fn decode_raw(mut wire: BytesMut) -> Result<ConnectPacket> {
        let header = FixedHeader::new(PacketType::Connect, 0, 0);
        ConnectPacket::decode_body(&mut wire, &header)
    }

    #[test]
    fn test_connect_packet_basic() {
        let packet = ConnectPacket::new(ConnectOptions::new("test-client"));

        assert_eq!(packet.protocol_version, PROTOCOL_VERSION_V5);
        assert!(packet.clean_start);
        assert_eq!(packet.keep_alive, 60);
        assert_eq!(packet.client_id, "test-client");
        assert!(packet.username.is_none());
        assert!(packet.password.is_none());
        assert!(packet.will.is_none());
    }

    #[test]
    fn test_connect_packet_with_credentials() {
        let packet = ConnectPacket::new(
            ConnectOptions::new("test-client").with_credentials("user", b"pass"),
        );

        assert_eq!(packet.username, Some("user".to_string()));
        assert_eq!(packet.password, Some(b"pass".to_vec()));
    }

    #[test]
    fn test_connect_packet_with_will() {
        let last_will = WillMessage::new("will/topic", b"will payload")
            .with_qos(QoS::AtLeastOnce)
            .with_retain(true);
        let packet = ConnectPacket::new(ConnectOptions::new("test-client").with_will(last_will));

        assert!(packet.will.is_some());
        let will = packet.will.as_ref().unwrap();
        assert_eq!(will.topic, "will/topic");
        assert_eq!(will.payload, b"will payload");
        assert_eq!(will.qos, QoS::AtLeastOnce);
        assert!(will.retain);
    }

    #[test]
    fn test_connect_flags() {
        // Clean start only.
        let defaults = ConnectPacket::new(ConnectOptions::new("test"));
        assert_eq!(defaults.connect_flags(), 0x02);

        // Username + Password, clean start cleared.
        let with_login = ConnectPacket::new(
            ConnectOptions::new("test")
                .with_clean_start(false)
                .with_credentials("user", b"pass"),
        );
        assert_eq!(with_login.connect_flags(), 0xC0);

        // Clean start + Will + QoS 2 + Will Retain.
        let last_will = WillMessage::new("topic", b"payload")
            .with_qos(QoS::ExactlyOnce)
            .with_retain(true);
        let with_will = ConnectPacket::new(ConnectOptions::new("test").with_will(last_will));
        assert_eq!(with_will.connect_flags(), 0x36);
    }

    #[test]
    fn test_connect_encode_decode_v5() {
        let original = ConnectPacket::new(
            ConnectOptions::new("test-client-123")
                .with_keep_alive(Duration::from_secs(120))
                .with_credentials("testuser", b"testpass"),
        );

        let (header, decoded) = round_trip(&original);
        assert_eq!(header.packet_type, PacketType::Connect);
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V5);
        assert_eq!(decoded.client_id, "test-client-123");
        assert_eq!(decoded.keep_alive, 120);
        assert_eq!(decoded.username, Some("testuser".to_string()));
        assert_eq!(decoded.password, Some(b"testpass".to_vec()));
    }

    #[test]
    fn test_connect_encode_decode_v311() {
        let original = ConnectPacket::new_v311(ConnectOptions::new("mqtt-311-client"));

        let (_, decoded) = round_trip(&original);
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V311);
        assert_eq!(decoded.client_id, "mqtt-311-client");
    }

    #[test]
    fn test_connect_invalid_protocol_name() {
        assert!(decode_raw(header_only("INVALID", 5)).is_err());
    }

    #[test]
    fn test_connect_invalid_protocol_version() {
        assert!(decode_raw(header_only("MQTT", 99)).is_err());
    }

    #[test]
    fn test_connect_password_without_username() {
        let mut wire = header_only("MQTT", 5);
        wire.put_u8(0x40); // Password flag only
        wire.put_u16(60); // Keep alive
        wire.put_u8(0); // Empty properties
        encode_string(&mut wire, "client").unwrap();
        encode_binary(&mut wire, b"password").unwrap();

        assert!(decode_raw(wire).is_err());
    }
}