// mqtt4bytes/packets/connect.rs
1use super::*;
2use crate::*;
3use alloc::string::String;
4use alloc::vec::Vec;
5use bytes::{Buf, Bytes};
6use core::fmt;
7
8/// Connection packet initiated by the client
9#[derive(Clone, PartialEq)]
10pub struct Connect {
11    /// Mqtt protocol version
12    pub protocol: Protocol,
13    /// Mqtt keep alive time
14    pub keep_alive: u16,
15    /// Client Id
16    pub client_id: String,
17    /// Clean session. Asks the broker to clear previous state
18    pub clean_session: bool,
19    /// Will that broker needs to publish when the client disconnects
20    pub last_will: Option<LastWill>,
21    /// Username and password
22    pub login: Option<Login>,
23}
24
25impl Connect {
26    pub fn new<S: Into<String>>(id: S) -> Connect {
27        Connect {
28            protocol: Protocol::MQTT(4),
29            keep_alive: 10,
30            client_id: id.into(),
31            clean_session: true,
32            last_will: None,
33            login: None,
34        }
35    }
36
37    pub(crate) fn assemble(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Connect, Error> {
38        let variable_header_index = fixed_header.fixed_len;
39        bytes.advance(variable_header_index);
40        let protocol_name = read_mqtt_string(&mut bytes)?;
41        let protocol_level = bytes.get_u8();
42        if protocol_name != "MQTT" {
43            return Err(Error::InvalidProtocol);
44        }
45
46        let protocol = match protocol_level {
47            4 => Protocol::MQTT(4),
48            num => return Err(Error::InvalidProtocolLevel(num)),
49        };
50
51        let connect_flags = bytes.get_u8();
52        let clean_session = (connect_flags & 0b10) != 0;
53        let keep_alive = bytes.get_u16();
54        let client_id = read_mqtt_string(&mut bytes)?;
55        let last_will = LastWill::extract(connect_flags, &mut bytes)?;
56        let login = Login::extract(connect_flags, &mut bytes)?;
57
58        let connect = Connect {
59            protocol,
60            keep_alive,
61            client_id,
62            clean_session,
63            last_will,
64            login,
65        };
66
67        Ok(connect)
68    }
69
70    /// Variable header length
71    fn len(&self) -> usize {
72        let mut len = 2 + "MQTT".len() // protocol name
73                              + 1  // protocol version
74                              + 1  // connect flags
75                              + 2; // keep alive
76
77        len += 2 + self.client_id.len();
78
79        // last will len
80        if let Some(last_will) = &self.last_will {
81            len += last_will.len();
82        }
83
84        // username and password len
85        if let Some(login) = &self.login {
86            len += login.len();
87        }
88
89        len
90    }
91
92    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
93        let len = self.len();
94        buffer.reserve(len);
95        buffer.put_u8(0b0001_0000);
96        let count = write_remaining_length(buffer, len)?;
97        write_mqtt_string(buffer, "MQTT");
98        buffer.put_u8(0x04);
99        let flags_index = 1 + count + 2 + 4 + 1;
100
101        let mut connect_flags = 0;
102        if self.clean_session {
103            connect_flags |= 0x02;
104        }
105
106        buffer.put_u8(connect_flags);
107        buffer.put_u16(self.keep_alive);
108        write_mqtt_string(buffer, &self.client_id);
109
110        if let Some(last_will) = &self.last_will {
111            connect_flags |= last_will.write(buffer)?;
112        }
113
114        if let Some(login) = &self.login {
115            connect_flags |= login.write(buffer);
116        }
117
118        // update connect flags
119        buffer[flags_index] = connect_flags;
120        Ok(1 + count + len)
121    }
122}
123
124/// LastWill that broker forwards on behalf of the client
125#[derive(Debug, Clone, PartialEq)]
126pub struct LastWill {
127    pub topic: String,
128    pub message: Bytes,
129    pub qos: QoS,
130    pub retain: bool,
131}
132
133impl LastWill {
134    pub fn new(topic: impl Into<String>, qos: QoS, payload: impl Into<Vec<u8>>) -> LastWill {
135        LastWill {
136            topic: topic.into(),
137            message: Bytes::from(payload.into()),
138            qos,
139            retain: false,
140        }
141    }
142
143    fn len(&self) -> usize {
144        let mut len = 0;
145        len += 2 + self.topic.len() + 2 + self.message.len();
146        len
147    }
148
149    fn extract(connect_flags: u8, mut bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
150        let last_will = match connect_flags & 0b100 {
151            0 if (connect_flags & 0b0011_1000) != 0 => {
152                return Err(Error::IncorrectPacketFormat);
153            }
154            0 => None,
155            _ => {
156                let will_topic = read_mqtt_string(&mut bytes)?;
157                let will_message = read_mqtt_bytes(&mut bytes)?;
158                let will_qos = qos((connect_flags & 0b11000) >> 3)?;
159                Some(LastWill {
160                    topic: will_topic,
161                    message: will_message,
162                    qos: will_qos,
163                    retain: (connect_flags & 0b0010_0000) != 0,
164                })
165            }
166        };
167
168        Ok(last_will)
169    }
170
171    fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
172        let mut connect_flags = 0;
173
174        connect_flags |= 0x04 | (self.qos as u8) << 3;
175        if self.retain {
176            connect_flags |= 0x20;
177        }
178
179        write_mqtt_string(buffer, &self.topic);
180        write_mqtt_bytes(buffer, &self.message);
181        Ok(connect_flags)
182    }
183}
184
/// Username/password credentials carried in a CONNECT packet.
///
/// The fields are public so that code which decodes a CONNECT (e.g. a broker
/// authenticating clients) can actually read the credentials.
#[derive(Debug, Clone, PartialEq)]
pub struct Login {
    /// Username; an empty string means no username is encoded.
    pub username: String,
    /// Password; an empty string means no password is encoded.
    pub password: String,
}

192impl Login {
193    pub fn new<S: Into<String>>(u: S, p: S) -> Login {
194        Login {
195            username: u.into(),
196            password: p.into(),
197        }
198    }
199
200    fn extract(connect_flags: u8, mut bytes: &mut Bytes) -> Result<Option<Login>, Error> {
201        let username = match connect_flags & 0b1000_0000 {
202            0 => String::new(),
203            _ => read_mqtt_string(&mut bytes)?,
204        };
205
206        let password = match connect_flags & 0b0100_0000 {
207            0 => String::new(),
208            _ => read_mqtt_string(&mut bytes)?,
209        };
210
211        if username.is_empty() && password.is_empty() {
212            Ok(None)
213        } else {
214            Ok(Some(Login { username, password }))
215        }
216    }
217
218    fn len(&self) -> usize {
219        let mut len = 0;
220
221        if !self.username.is_empty() {
222            len += 2 + self.username.len();
223        }
224
225        if !self.password.is_empty() {
226            len += 2 + self.password.len();
227        }
228
229        len
230    }
231
232    fn write(&self, buffer: &mut BytesMut) -> u8 {
233        let mut connect_flags = 0;
234        if !self.username.is_empty() {
235            connect_flags |= 0x80;
236            write_mqtt_string(buffer, &self.username);
237        }
238
239        if !self.password.is_empty() {
240            connect_flags |= 0x40;
241            write_mqtt_string(buffer, &self.password);
242        }
243
244        connect_flags
245    }
246}
247
248impl fmt::Debug for Connect {
249    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250        write!(
251            f,
252            "Protocol = {:?}, Keep alive = {:?}, Client id = {}, Clean session = {}",
253            self.protocol, self.keep_alive, self.client_id, self.clean_session,
254        )
255    }
256}
257
#[cfg(test)]
mod test {
    use crate::*;
    use alloc::borrow::ToOwned;
    use alloc::vec;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    #[test]
    fn connect_stitching_works_correctly() {
        let mut stream = bytes::BytesMut::new();
        let packetstream = &[
            0x10,
            39, // packet type, flags and remaining len
            0x00,
            0x04,
            b'M',
            b'Q',
            b'T',
            b'T',
            0x04,        // variable header
            0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session
            0x00,
            0x0a, // variable header. keep alive = 10 sec
            0x00,
            0x04,
            b't',
            b'e',
            b's',
            b't', // payload. client_id
            0x00,
            0x02,
            b'/',
            b'a', // payload. will topic = '/a'
            0x00,
            0x07,
            b'o',
            b'f',
            b'f',
            b'l',
            b'i',
            b'n',
            b'e', // payload. will msg = 'offline'
            0x00,
            0x04,
            b'r',
            b'u',
            b'm',
            b'q', // payload. username = 'rumq'
            0x00,
            0x02,
            b'm',
            b'q', // payload. password = 'mq'
            0xDE,
            0xAD,
            0xBE,
            0xEF, // extra packets in the stream
        ];

        stream.extend_from_slice(&packetstream[..]);
        let packet = mqtt_read(&mut stream, 100).unwrap();
        let packet = match packet {
            Packet::Connect(connect) => connect,
            packet => panic!("Invalid packet = {:?}", packet),
        };

        assert_eq!(
            packet,
            Connect {
                protocol: Protocol::MQTT(4),
                keep_alive: 10,
                client_id: "test".to_owned(),
                clean_session: true,
                last_will: Some(LastWill::new("/a", QoS::AtLeastOnce, "offline")),
                login: Some(Login::new("rumq", "mq"))
            }
        );
    }

    #[test]
    fn connack_stitching_works_correctly() {
        let mut stream = bytes::BytesMut::new();
        let packetstream = &[
            0b0010_0000,
            0x02, // packet type, flags and remaining len
            0x01,
            0x00, // variable header. connack flags, connect return code
            0xDE,
            0xAD,
            0xBE,
            0xEF, // extra packets in the stream
        ];

        stream.extend_from_slice(&packetstream[..]);
        let packet = mqtt_read(&mut stream, 100).unwrap();
        let packet = match packet {
            Packet::ConnAck(packet) => packet,
            packet => panic!("Invalid packet = {:?}", packet),
        };

        assert_eq!(
            packet,
            ConnAck {
                session_present: true,
                code: ConnectReturnCode::Accepted
            }
        );
    }

    #[test]
    fn write_connect_mqtt_packet_works() {
        let connect = Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", QoS::AtLeastOnce, "offline")),
            login: Some(Login::new("rust", "mq")),
        };

        let mut buf = BytesMut::new();
        connect.write(&mut buf).unwrap();

        assert_eq!(
            buf,
            vec![
                0x10,
                39,
                0x00,
                0x04,
                b'M',
                b'Q',
                b'T',
                b'T',
                0x04,
                0b1100_1110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session
                0x00,
                0x0a, // 10 sec
                0x00,
                0x04,
                b't',
                b'e',
                b's',
                b't', // client_id
                0x00,
                0x02,
                b'/',
                b'a', // will topic = '/a'
                0x00,
                0x07,
                b'o',
                b'f',
                b'f',
                b'l',
                b'i',
                b'n',
                b'e', // will msg = 'offline'
                0x00,
                0x04,
                b'r',
                b'u',
                b's',
                b't', // username = 'rust'
                0x00,
                0x02,
                b'm',
                b'q' // password = 'mq'
            ]
        );
    }

    /// Encoding and then decoding a CONNECT must yield the original packet,
    /// and `write` must report exactly the number of bytes it appended.
    #[test]
    fn connect_write_then_read_roundtrips() {
        let connect = Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", QoS::AtLeastOnce, "offline")),
            login: Some(Login::new("rust", "mq")),
        };

        let mut buf = BytesMut::new();
        let written = connect.write(&mut buf).unwrap();
        assert_eq!(written, buf.len());

        let packet = mqtt_read(&mut buf, 100).unwrap();
        let packet = match packet {
            Packet::Connect(packet) => packet,
            packet => panic!("Invalid packet = {:?}", packet),
        };

        assert_eq!(packet, connect);
    }
}
424                b'q' // password = 'mq'
425            ]
426        );
427    }
428}