mqute_codec/protocol/common/
connect.rs

1//! # Connect Packet
2//!
3//! This module provides structures and utilities for handling the MQTT Connect packet,
4//! which is used to initiate a connection between a client and an MQTT broker.
5
6use crate::codec::util::{decode_byte, decode_string, decode_word, encode_string};
7use crate::protocol::common::frame::WillFrame;
8use crate::protocol::Protocol;
9use crate::Error;
10use bit_field::BitField;
11use bytes::{Buf, BufMut, Bytes, BytesMut};
12
13/// Represents the header of the MQTT Connect packet.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub(crate) struct ConnectHeader<T> {
16    pub protocol: Protocol,
17    pub flags: u8,
18    pub keep_alive: u16,
19    pub properties: Option<T>,
20}
21
22impl<T> ConnectHeader<T> {
23    /// Creates a new `ConnectHeader`.
24    pub(crate) fn new(
25        protocol: Protocol,
26        flags: u8,
27        keep_alive: u16,
28        properties: Option<T>,
29    ) -> Self {
30        ConnectHeader {
31            protocol,
32            flags,
33            keep_alive,
34            properties,
35        }
36    }
37
38    /// Calculates the length of the primary encoded header.
39    ///
40    /// This includes the protocol name, protocol level, flags, and keep-alive duration.
41    pub(crate) fn primary_encoded_len(&self) -> usize {
42        2 + self.protocol.name().len() // Protocol name string
43            + 1                        // Protocol level
44            + 1                        // Connect flags
45            + 2 // Keep alive
46    }
47
48    /// Encodes the primary header fields into the provided buffer.
49    pub(crate) fn primary_encode(&self, buf: &mut BytesMut) {
50        // Encode the protocol name
51        encode_string(buf, self.protocol.name());
52
53        // Add the protocol level
54        buf.put_u8(self.protocol.into());
55
56        // Add the flags
57        buf.put_u8(self.flags);
58
59        // Add the keep alive timeout
60        buf.put_u16(self.keep_alive);
61    }
62
63    /// Decodes the primary header fields from the provided buffer.
64    pub(crate) fn primary_decode(buf: &mut Bytes) -> Result<Self, Error> {
65        let protocol_name = decode_string(buf)?;
66
67        let protocol: Protocol = buf.get_u8().try_into()?;
68
69        if protocol_name != protocol.name() {
70            return Err(Error::InvalidProtocolName(protocol_name));
71        }
72
73        let flags = decode_byte(buf)?;
74        let keep_alive = decode_word(buf)?;
75
76        Ok(ConnectHeader {
77            protocol,
78            flags,
79            keep_alive,
80            properties: None,
81        })
82    }
83}
84
85const PASSWORD: usize = 6;
86const USERNAME: usize = 7;
87
88/// Represents authentication information for an MQTT connection.
89///
90/// The `Credentials` struct encapsulates the username and optional password used for authenticating
91/// a client with an MQTT broker. It provides methods for creating, encoding, decoding, and
92/// manipulating authentication data.
93///
94/// # Examples
95///
96/// ```rust
97/// use mqute_codec::protocol::Credentials;
98///
99/// let credentials = Credentials::login("user", "pass");
100/// assert_eq!(credentials.username(), "user");
101/// assert_eq!(credentials.password(), Some("pass".to_string()));
102/// ```
103#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct Credentials {
105    /// The username for authentication.
106    username: String,
107
108    /// An optional password for authentication
109    password: Option<String>,
110}
111
112impl Credentials {
113    /// Creates a new `Credentials` instance.
114    pub fn new<T>(username: T, password: Option<String>) -> Self
115    where
116        T: Into<String>,
117    {
118        Credentials {
119            username: username.into(),
120            password,
121        }
122    }
123
124    /// Creates a new `Credentials` instance with only a username.
125    pub fn with_name<T: Into<String>>(username: T) -> Self {
126        Self::new(username, None)
127    }
128
129    /// Creates a new `Credentials` instance with both a username and password.
130    pub fn login<T: Into<String>, U: Into<String>>(username: T, password: U) -> Self {
131        Self::new(username.into(), Some(password.into()))
132    }
133
134    /// Returns the username.
135    pub fn username(&self) -> String {
136        self.username.clone()
137    }
138
139    /// Returns the optional password.
140    pub fn password(&self) -> Option<String> {
141        self.password.clone()
142    }
143
144    /// Calculates the encoded length of the `Credentials` structure.
145    ///
146    /// This is used to determine the size of the buffer required to encode the `Credentials` data.
147    pub(crate) fn encoded_len(&self) -> usize {
148        let mut size = 2 + self.username.len(); // 2 bytes for string length + username length
149        if let Some(password) = self.password.as_ref() {
150            size += 2 + password.len(); // 2 bytes for string length + password length
151        }
152        size
153    }
154
155    /// Encodes the `Credentials` structure into the provided buffer.
156    pub(crate) fn encode(&self, buf: &mut BytesMut) {
157        encode_string(buf, &self.username);
158
159        if let Some(password) = self.password.as_ref() {
160            encode_string(buf, password);
161        }
162    }
163
164    /// Updates the connection flags based on the presence of a username and password.
165    pub(crate) fn update_flags(&self, flags: &mut u8) {
166        // Update username flag
167        flags.set_bit(USERNAME, true);
168
169        // Update password flag
170        flags.set_bit(PASSWORD, self.password.is_some());
171    }
172
173    /// Decodes the `Credentials` structure from the provided buffer and flags.
174    pub(crate) fn decode(buf: &mut Bytes, flags: u8) -> Result<Option<Self>, Error> {
175        if !flags.get_bit(USERNAME) {
176            return Ok(None);
177        }
178
179        let username = decode_string(buf)?;
180
181        let password = if flags.get_bit(PASSWORD) {
182            Some(decode_string(buf)?)
183        } else {
184            None
185        };
186
187        Ok(Some(Credentials::new(username, password)))
188    }
189}
190
191/// Represents the payload of the MQTT Connect packet.
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub(crate) struct ConnectPayload<T> {
194    pub client_id: String,
195    pub credentials: Option<Credentials>,
196    pub will: Option<T>,
197}
198
199impl<T> ConnectPayload<T>
200where
201    T: WillFrame,
202{
203    /// Creates a new `ConnectPayload`.
204    pub(crate) fn new<S: Into<String>>(
205        client_id: S,
206        credentials: Option<Credentials>,
207        will: Option<T>,
208    ) -> Self {
209        ConnectPayload {
210            client_id: client_id.into(),
211            credentials,
212            will,
213        }
214    }
215
216    /// Decodes the `ConnectPayload` from the provided buffer and flags.
217    pub(crate) fn decode(payload: &mut Bytes, flags: u8) -> Result<Self, Error> {
218        let client_id = decode_string(payload)?;
219
220        let will = T::decode(payload, flags)?;
221        let credentials = Credentials::decode(payload, flags)?;
222
223        Ok(ConnectPayload {
224            client_id,
225            credentials,
226            will,
227        })
228    }
229
230    /// Encodes the `ConnectPayload` into the provided buffer.
231    pub(crate) fn encode(&self, buf: &mut BytesMut) -> Result<(), Error> {
232        // Encode the client id
233        encode_string(buf, &self.client_id);
234
235        if let Some(will) = self.will.as_ref() {
236            will.encode(buf)?;
237        }
238
239        if let Some(credentials) = self.credentials.as_ref() {
240            credentials.encode(buf);
241        }
242
243        Ok(())
244    }
245
246    /// Calculates the encoded length of the `ConnectPayload`.
247    pub(crate) fn encoded_len(&self) -> usize {
248        2 + self.client_id.len() +            // Client ID
249            self.will                         // WillFlag
250                .as_ref()
251                .map(|will| will.encoded_len())
252                .unwrap_or(0) +
253            self.credentials                         // Credentials
254                .as_ref()
255                .map(|credentials| credentials.encoded_len())
256                .unwrap_or(0)
257    }
258}
259
/// Generates a Connect packet type for one MQTT protocol version.
///
/// `connect!(Name<Property, Will>, protocol)` expands to:
/// - a `Name` struct pairing a `ConnectHeader<Property>` with a `ConnectPayload<Will>`,
/// - a module-level `CLEAN_SESSION` constant (bit 1 of the connect flags),
/// - constructors and accessors on `Name`,
/// - `Encode` and `Decode` implementations; decoding rejects packets whose protocol
///   differs from `protocol`.
macro_rules! connect {
    ($name:ident <$property:ident, $will:ident>, $proto:expr) => {
        /// Represents an MQTT `Connect` packet
        ///
        /// The first packet a client sends: it opens the connection to the broker and
        /// carries the parameters for the session.
        ///
        /// # Example
        ///
        /// ```rust
        /// use std::time::Duration;
        /// use bytes::Bytes;
        /// use mqute_codec::protocol::{v5, Credentials, Protocol, QoS};
        ///
        /// let connect = v5::Connect::new(
        ///     "client",
        ///     Some(Credentials::login("user", "pass")),
        ///     Some(v5::Will::new(
        ///         None,
        ///         "device/status",
        ///         Bytes::from("disconnected"),
        ///         QoS::ExactlyOnce,
        ///         true
        ///     )),
        ///     Duration::from_secs(30).as_secs() as u16,
        ///     true
        /// );
        /// assert!(connect.will().is_some());
        /// assert_eq!(connect.protocol(), Protocol::V5);
        /// assert_eq!(connect.client_id(), "client");
        /// ```
        #[derive(Debug, Clone, PartialEq, Eq)]
        pub struct $name {
            header: $crate::protocol::common::ConnectHeader<$property>,
            payload: $crate::protocol::common::ConnectPayload<$will>,
        }

        /// Bit index of the clean session flag inside the connect-flags byte.
        const CLEAN_SESSION: usize = 1;

        impl $name {
            /// Builds a packet from every field.
            ///
            /// The connect-flags byte is derived from `clean_session`, then the
            /// credentials, then the will.
            fn from_scratch<S: Into<String>>(
                client_id: S,
                credentials: Option<$crate::protocol::common::Credentials>,
                will: Option<$will>,
                properties: Option<$property>,
                keep_alive: u16,
                clean_session: bool,
            ) -> Self {
                use bit_field::BitField;
                use $crate::protocol::common::WillFrame;

                let mut connect_flags = 0u8;
                connect_flags.set_bit(CLEAN_SESSION, clean_session);

                if let Some(auth) = &credentials {
                    auth.update_flags(&mut connect_flags);
                }

                if let Some(will_frame) = &will {
                    will_frame.update_flags(&mut connect_flags);
                }

                let header = $crate::protocol::common::ConnectHeader::<$property>::new(
                    $proto,
                    connect_flags,
                    keep_alive,
                    properties,
                );
                let payload = $crate::protocol::common::ConnectPayload::<$will>::new(
                    client_id,
                    credentials,
                    will,
                );

                $name { header, payload }
            }

            /// Creates a new Connect packet without protocol properties
            pub fn new<S: Into<String>>(
                client_id: S,
                credentials: Option<$crate::protocol::Credentials>,
                will: Option<$will>,
                keep_alive: u16,
                clean_session: bool,
            ) -> Self {
                let properties: Option<$property> = None;
                Self::from_scratch(
                    client_id,
                    credentials,
                    will,
                    properties,
                    keep_alive,
                    clean_session,
                )
            }

            /// Returns the protocol version of this packet
            pub fn protocol(&self) -> $crate::protocol::Protocol {
                self.header.protocol
            }

            /// Returns the keep alive time in seconds
            pub fn keep_alive(&self) -> u16 {
                self.header.keep_alive
            }

            /// Returns `true` when the clean session flag is set
            pub fn clean_session(&self) -> bool {
                use bit_field::BitField;
                let connect_flags = self.header.flags;
                connect_flags.get_bit(CLEAN_SESSION)
            }

            /// Returns an owned copy of the client identifier
            pub fn client_id(&self) -> String {
                self.payload.client_id.to_owned()
            }

            /// Returns a copy of the authentication credentials, if present
            pub fn credentials(&self) -> Option<$crate::protocol::common::Credentials> {
                self.payload.credentials.as_ref().cloned()
            }

            /// Returns a copy of the will message, if present
            pub fn will(&self) -> Option<$will> {
                self.payload.will.as_ref().cloned()
            }
        }

        impl $crate::codec::Encode for $name {
            fn encode(&self, buf: &mut bytes::BytesMut) -> Result<(), $crate::Error> {
                use $crate::protocol::common::ConnectFrame;

                // The fixed header's remaining length covers the variable header + payload.
                let fixed = $crate::protocol::FixedHeader::new(
                    $crate::protocol::PacketType::Connect,
                    self.payload_len(),
                );
                fixed.encode(buf)?;

                self.header.encode(buf)?;
                self.payload.encode(buf)
            }

            fn payload_len(&self) -> usize {
                use $crate::protocol::common::ConnectFrame;
                let variable_header_len = self.header.encoded_len();
                let body_len = self.payload.encoded_len();
                variable_header_len + body_len
            }
        }

        impl $crate::codec::Decode for $name {
            fn decode(mut packet: $crate::codec::RawPacket) -> Result<Self, $crate::Error> {
                use $crate::protocol::common::ConnectFrame;

                let is_connect =
                    packet.header.packet_type() == $crate::protocol::PacketType::Connect;
                if !is_connect || !packet.header.flags().is_default() {
                    return Err($crate::Error::MalformedPacket);
                }

                let header = $crate::protocol::common::ConnectHeader::<$property>::decode(
                    &mut packet.payload,
                )?;

                // A packet for another protocol version must be handled by that version's type.
                if header.protocol != $proto {
                    return Err($crate::Error::ProtocolNotSupported);
                }

                let connect_flags = header.flags;
                let payload = $crate::protocol::common::ConnectPayload::<$will>::decode(
                    &mut packet.payload,
                    connect_flags,
                )?;

                Ok(Self { header, payload })
            }
        }
    };
}
441
442pub(crate) use connect;