mqute_codec/protocol/common/
connect.rs

1//! # Connect Packet
2//!
3//! This module provides structures and utilities for handling the MQTT Connect packet,
4//! which is used to initiate a connection between a client and an MQTT broker.
5
6use crate::Error;
7use crate::codec::util::{decode_byte, decode_string, decode_word, encode_string};
8use crate::protocol::Protocol;
9use crate::protocol::common::frame::WillFrame;
10use bit_field::BitField;
11use bytes::{Buf, BufMut, Bytes, BytesMut};
12use std::time::Duration;
13
14/// Represents the header of the MQTT Connect packet.
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub(crate) struct ConnectHeader<T> {
17    pub protocol: Protocol,
18    pub flags: u8,
19    pub keep_alive: Duration,
20    pub properties: Option<T>,
21}
22
23impl<T> ConnectHeader<T> {
24    /// Creates a new `ConnectHeader`.
25    pub(crate) fn new(
26        protocol: Protocol,
27        flags: u8,
28        keep_alive: Duration,
29        properties: Option<T>,
30    ) -> Self {
31        ConnectHeader {
32            protocol,
33            flags,
34            keep_alive,
35            properties,
36        }
37    }
38
39    /// Calculates the length of the primary encoded header.
40    ///
41    /// This includes the protocol name, protocol level, flags, and keep-alive duration.
42    pub(crate) fn primary_encoded_len(&self) -> usize {
43        2 + self.protocol.name().len() // Protocol name string
44            + 1                        // Protocol level
45            + 1                        // Connect flags
46            + 2 // Keep alive
47    }
48
49    /// Encodes the primary header fields into the provided buffer.
50    pub(crate) fn primary_encode(&self, buf: &mut BytesMut) {
51        // Encode the protocol name
52        encode_string(buf, self.protocol.name());
53
54        // Add the protocol level
55        buf.put_u8(self.protocol.into());
56
57        // Add the flags
58        buf.put_u8(self.flags);
59
60        // Add the keep alive timeout
61        buf.put_u16(self.keep_alive.as_secs() as u16);
62    }
63
64    /// Decodes the primary header fields from the provided buffer.
65    pub(crate) fn primary_decode(buf: &mut Bytes) -> Result<Self, Error> {
66        let protocol_name = decode_string(buf)?;
67
68        let protocol: Protocol = buf.get_u8().try_into()?;
69
70        if protocol_name != protocol.name() {
71            return Err(Error::InvalidProtocolName(protocol_name));
72        }
73
74        let flags = decode_byte(buf)?;
75        let keep_alive = decode_word(buf)?;
76
77        Ok(ConnectHeader {
78            protocol,
79            flags,
80            keep_alive: Duration::from_secs(keep_alive as u64),
81            properties: None,
82        })
83    }
84}
85
86const PASSWORD: usize = 6;
87const USERNAME: usize = 7;
88
89/// Represents authentication information for an MQTT connection.
90///
91/// The `Credentials` struct encapsulates the username and optional password used for authenticating
92/// a client with an MQTT broker. It provides methods for creating, encoding, decoding, and
93/// manipulating authentication data.
94///
95/// # Examples
96///
97/// ```rust
98/// use mqute_codec::protocol::Credentials;
99///
100/// let credentials = Credentials::full("user", "pass");
101/// assert_eq!(credentials.username(), "user");
102/// assert_eq!(credentials.password(), Some("pass".to_string()));
103/// ```
104#[derive(Debug, Clone, PartialEq, Eq)]
105pub struct Credentials {
106    /// The username for authentication.
107    username: String,
108
109    /// An optional password for authentication
110    password: Option<String>,
111}
112
113impl Credentials {
114    /// Creates a new `Credentials` instance.
115    pub fn new<T>(username: T, password: Option<String>) -> Self
116    where
117        T: Into<String>,
118    {
119        Credentials {
120            username: username.into(),
121            password,
122        }
123    }
124
125    /// Creates a new `Credentials` instance with only a username.
126    pub fn with_name<T: Into<String>>(username: T) -> Self {
127        Self::new(username, None)
128    }
129
130    /// Creates a new `Credentials` instance with both a username and password.
131    pub fn full<T: Into<String>, U: Into<String>>(username: T, password: U) -> Self {
132        Self::new(username.into(), Some(password.into()))
133    }
134
135    /// Returns the username.
136    pub fn username(&self) -> String {
137        self.username.clone()
138    }
139
140    /// Returns the optional password.
141    pub fn password(&self) -> Option<String> {
142        self.password.clone()
143    }
144
145    /// Calculates the encoded length of the `Credentials` structure.
146    ///
147    /// This is used to determine the size of the buffer required to encode the `Credentials` data.
148    pub(crate) fn encoded_len(&self) -> usize {
149        let mut size = 2 + self.username.len(); // 2 bytes for string length + username length
150        if let Some(password) = self.password.as_ref() {
151            size += 2 + password.len(); // 2 bytes for string length + password length
152        }
153        size
154    }
155
156    /// Encodes the `Credentials` structure into the provided buffer.
157    pub(crate) fn encode(&self, buf: &mut BytesMut) {
158        encode_string(buf, &self.username);
159
160        if let Some(password) = self.password.as_ref() {
161            encode_string(buf, password);
162        }
163    }
164
165    /// Updates the connection flags based on the presence of a username and password.
166    pub(crate) fn update_flags(&self, flags: &mut u8) {
167        // Update username flag
168        flags.set_bit(USERNAME, true);
169
170        // Update password flag
171        flags.set_bit(PASSWORD, self.password.is_some());
172    }
173
174    /// Decodes the `Credentials` structure from the provided buffer and flags.
175    pub(crate) fn decode(buf: &mut Bytes, flags: u8) -> Result<Option<Self>, Error> {
176        if !flags.get_bit(USERNAME) {
177            return Ok(None);
178        }
179
180        let username = decode_string(buf)?;
181
182        let password = if flags.get_bit(PASSWORD) {
183            Some(decode_string(buf)?)
184        } else {
185            None
186        };
187
188        Ok(Some(Credentials::new(username, password)))
189    }
190}
191
192/// Represents the payload of the MQTT Connect packet.
193#[derive(Debug, Clone, PartialEq, Eq)]
194pub(crate) struct ConnectPayload<T> {
195    pub client_id: String,
196    pub credentials: Option<Credentials>,
197    pub will: Option<T>,
198}
199
200impl<T> ConnectPayload<T>
201where
202    T: WillFrame,
203{
204    /// Creates a new `ConnectPayload`.
205    pub(crate) fn new<S: Into<String>>(
206        client_id: S,
207        credentials: Option<Credentials>,
208        will: Option<T>,
209    ) -> Self {
210        ConnectPayload {
211            client_id: client_id.into(),
212            credentials,
213            will,
214        }
215    }
216
217    /// Decodes the `ConnectPayload` from the provided buffer and flags.
218    pub(crate) fn decode(payload: &mut Bytes, flags: u8) -> Result<Self, Error> {
219        let client_id = decode_string(payload)?;
220
221        let will = T::decode(payload, flags)?;
222        let credentials = Credentials::decode(payload, flags)?;
223
224        Ok(ConnectPayload {
225            client_id,
226            credentials,
227            will,
228        })
229    }
230
231    /// Encodes the `ConnectPayload` into the provided buffer.
232    pub(crate) fn encode(&self, buf: &mut BytesMut) -> Result<(), Error> {
233        // Encode the client id
234        encode_string(buf, &self.client_id);
235
236        if let Some(will) = self.will.as_ref() {
237            will.encode(buf)?;
238        }
239
240        if let Some(credentials) = self.credentials.as_ref() {
241            credentials.encode(buf);
242        }
243
244        Ok(())
245    }
246
247    /// Calculates the encoded length of the `ConnectPayload`.
248    pub(crate) fn encoded_len(&self) -> usize {
249        2 + self.client_id.len() +            // Client ID
250            self.will                         // WillFlag
251                .as_ref()
252                .map(|will| will.encoded_len())
253                .unwrap_or(0) +
254            self.credentials                         // Credentials
255                .as_ref()
256                .map(|credentials| credentials.encoded_len())
257                .unwrap_or(0)
258    }
259}
260
/// Generates a Connect packet structure with specific properties and will message types.
///
/// The `connect!` macro is used to generate a Connect packet structure that includes
/// the header, payload, and encoding/decoding logic for a specific MQTT protocol version.
///
/// Parameters:
/// - `$name`: name of the generated packet struct.
/// - `$property`: properties type stored in the connect header.
/// - `$will`: will message type stored in the payload (must implement `WillFrame`).
/// - `$proto`: the `Protocol` value written on encode and required on decode.
///
/// The expansion also defines a module-level `CLEAN_SESSION` constant, so the macro can
/// be invoked at most once per module.
macro_rules! connect {
    ($name:ident <$property:ident, $will:ident>, $proto:expr) => {
        /// Represents an MQTT `Connect` packet
        ///
        /// This packet initiates a connection between client and broker and contains
        /// all necessary parameters for the session.
        ///
        /// # Example
        ///
        /// ```rust
        /// use std::time::Duration;
        /// use bytes::Bytes;
        /// use mqute_codec::protocol::{v5, Credentials, Protocol, QoS};
        ///
        /// let connect = v5::Connect::new(
        ///     "client",
        ///     Some(Credentials::full("user", "pass")),
        ///     Some(v5::Will::new(
        ///         None,
        ///         "device/status",
        ///         Bytes::from("disconnected"),
        ///         QoS::ExactlyOnce,
        ///         true
        ///     )),
        ///     Duration::from_secs(30),
        ///     true
        /// );
        /// assert!(connect.will().is_some());
        /// assert_eq!(connect.protocol(), Protocol::V5);
        /// assert_eq!(connect.client_id(), "client");
        /// ```
        #[derive(Debug, Clone, PartialEq, Eq)]
        pub struct $name {
            header: $crate::protocol::common::ConnectHeader<$property>,
            payload: $crate::protocol::common::ConnectPayload<$will>,
        }

        /// Bit position of the Clean Session (MQTT 5: Clean Start) flag in the
        /// connect-flags byte.
        const CLEAN_SESSION: usize = 1;

        impl $name {
            /// Builds a packet from all of its parts, deriving the connect flags from
            /// `clean_session`, `credentials` and `will`.
            ///
            /// # Panics
            ///
            /// Panics if `keep_alive` exceeds 65535 seconds.
            fn from_scratch<S: Into<String>>(
                client_id: S,
                credentials: Option<$crate::protocol::Credentials>,
                will: Option<$will>,
                properties: Option<$property>,
                keep_alive: std::time::Duration,
                clean_session: bool,
            ) -> Self {
                use bit_field::BitField;
                use $crate::protocol::common::WillFrame;

                // The keep-alive field is a 16-bit count of seconds on the wire.
                assert!(
                    keep_alive.as_secs() <= u64::from(u16::MAX),
                    "Invalid 'keep alive' value"
                );

                let mut flags = 0u8;

                flags.set_bit(CLEAN_SESSION, clean_session);

                if let Some(credentials) = credentials.as_ref() {
                    credentials.update_flags(&mut flags);
                }

                if let Some(will) = will.as_ref() {
                    will.update_flags(&mut flags);
                }

                let header = $crate::protocol::common::ConnectHeader::<$property>::new(
                    $proto, flags, keep_alive, properties,
                );
                let payload = $crate::protocol::common::ConnectPayload::<$will>::new(
                    client_id,
                    credentials,
                    will,
                );

                Self { header, payload }
            }

            /// Creates a new Connect packet with basic parameters
            ///
            /// # Panics
            ///
            /// Panics if the value of the "keep alive" parameter exceeds 65535 seconds
            pub fn new<S: Into<String>>(
                client_id: S,
                credentials: Option<$crate::protocol::Credentials>,
                will: Option<$will>,
                keep_alive: std::time::Duration,
                clean_session: bool,
            ) -> Self {
                Self::from_scratch(
                    client_id,
                    credentials,
                    will,
                    None,
                    keep_alive,
                    clean_session,
                )
            }

            /// Returns the protocol version being used
            pub fn protocol(&self) -> $crate::protocol::Protocol {
                self.header.protocol
            }

            /// Returns the keep-alive interval
            pub fn keep_alive(&self) -> std::time::Duration {
                self.header.keep_alive
            }

            /// Returns whether the clean session flag is set
            pub fn clean_session(&self) -> bool {
                use bit_field::BitField;
                self.header.flags.get_bit(CLEAN_SESSION)
            }

            /// Returns a copy of the client identifier
            pub fn client_id(&self) -> String {
                self.payload.client_id.clone()
            }

            /// Returns a copy of the authentication credentials, if present
            pub fn credentials(&self) -> Option<$crate::protocol::Credentials> {
                self.payload.credentials.clone()
            }

            /// Returns a copy of the will message, if present
            pub fn will(&self) -> Option<$will> {
                self.payload.will.clone()
            }
        }

        impl $crate::codec::Encode for $name {
            fn encode(&self, buf: &mut bytes::BytesMut) -> Result<(), $crate::Error> {
                use $crate::protocol::common::ConnectFrame;

                let header = $crate::protocol::FixedHeader::new(
                    $crate::protocol::PacketType::Connect,
                    self.payload_len(),
                );

                // Encode fixed header
                header.encode(buf)?;

                // Encode variable header
                self.header.encode(buf)?;

                // Encode payload
                self.payload.encode(buf)
            }

            fn payload_len(&self) -> usize {
                use $crate::protocol::common::ConnectFrame;
                self.header.encoded_len() + self.payload.encoded_len()
            }
        }

        impl $crate::codec::Decode for $name {
            fn decode(mut packet: $crate::codec::RawPacket) -> Result<Self, $crate::Error> {
                use $crate::protocol::common::ConnectFrame;

                // A CONNECT fixed header must carry the Connect type and default flags.
                if packet.header.packet_type() != $crate::protocol::PacketType::Connect
                    || !packet.header.flags().is_default()
                {
                    return Err($crate::Error::MalformedPacket);
                }

                let header = $crate::protocol::common::ConnectHeader::<$property>::decode(
                    &mut packet.payload,
                )?;

                // Reject packets for a different protocol version than this type models.
                if header.protocol != $proto {
                    return Err($crate::Error::ProtocolNotSupported);
                }

                let payload = $crate::protocol::common::ConnectPayload::<$will>::decode(
                    &mut packet.payload,
                    header.flags,
                )?;

                Ok(Self { header, payload })
            }
        }

        impl $crate::protocol::traits::Connect for $name {}
    };
}
452
453pub(crate) use connect;