1use super::*;
2use crate::*;
3use alloc::string::String;
4use alloc::vec::Vec;
5use bytes::{Buf, Bytes};
6use core::fmt;
7
8#[derive(Clone, PartialEq)]
10pub struct Connect {
11 pub protocol: Protocol,
13 pub keep_alive: u16,
15 pub client_id: String,
17 pub clean_session: bool,
19 pub last_will: Option<LastWill>,
21 pub login: Option<Login>,
23}
24
25impl Connect {
26 pub fn new<S: Into<String>>(id: S) -> Connect {
27 Connect {
28 protocol: Protocol::MQTT(4),
29 keep_alive: 10,
30 client_id: id.into(),
31 clean_session: true,
32 last_will: None,
33 login: None,
34 }
35 }
36
37 pub(crate) fn assemble(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Connect, Error> {
38 let variable_header_index = fixed_header.fixed_len;
39 bytes.advance(variable_header_index);
40 let protocol_name = read_mqtt_string(&mut bytes)?;
41 let protocol_level = bytes.get_u8();
42 if protocol_name != "MQTT" {
43 return Err(Error::InvalidProtocol);
44 }
45
46 let protocol = match protocol_level {
47 4 => Protocol::MQTT(4),
48 num => return Err(Error::InvalidProtocolLevel(num)),
49 };
50
51 let connect_flags = bytes.get_u8();
52 let clean_session = (connect_flags & 0b10) != 0;
53 let keep_alive = bytes.get_u16();
54 let client_id = read_mqtt_string(&mut bytes)?;
55 let last_will = LastWill::extract(connect_flags, &mut bytes)?;
56 let login = Login::extract(connect_flags, &mut bytes)?;
57
58 let connect = Connect {
59 protocol,
60 keep_alive,
61 client_id,
62 clean_session,
63 last_will,
64 login,
65 };
66
67 Ok(connect)
68 }
69
70 fn len(&self) -> usize {
72 let mut len = 2 + "MQTT".len() + 1 + 1 + 2; len += 2 + self.client_id.len();
78
79 if let Some(last_will) = &self.last_will {
81 len += last_will.len();
82 }
83
84 if let Some(login) = &self.login {
86 len += login.len();
87 }
88
89 len
90 }
91
92 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
93 let len = self.len();
94 buffer.reserve(len);
95 buffer.put_u8(0b0001_0000);
96 let count = write_remaining_length(buffer, len)?;
97 write_mqtt_string(buffer, "MQTT");
98 buffer.put_u8(0x04);
99 let flags_index = 1 + count + 2 + 4 + 1;
100
101 let mut connect_flags = 0;
102 if self.clean_session {
103 connect_flags |= 0x02;
104 }
105
106 buffer.put_u8(connect_flags);
107 buffer.put_u16(self.keep_alive);
108 write_mqtt_string(buffer, &self.client_id);
109
110 if let Some(last_will) = &self.last_will {
111 connect_flags |= last_will.write(buffer)?;
112 }
113
114 if let Some(login) = &self.login {
115 connect_flags |= login.write(buffer);
116 }
117
118 buffer[flags_index] = connect_flags;
120 Ok(1 + count + len)
121 }
122}
123
124#[derive(Debug, Clone, PartialEq)]
126pub struct LastWill {
127 pub topic: String,
128 pub message: Bytes,
129 pub qos: QoS,
130 pub retain: bool,
131}
132
133impl LastWill {
134 pub fn new(topic: impl Into<String>, qos: QoS, payload: impl Into<Vec<u8>>) -> LastWill {
135 LastWill {
136 topic: topic.into(),
137 message: Bytes::from(payload.into()),
138 qos,
139 retain: false,
140 }
141 }
142
143 fn len(&self) -> usize {
144 let mut len = 0;
145 len += 2 + self.topic.len() + 2 + self.message.len();
146 len
147 }
148
149 fn extract(connect_flags: u8, mut bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
150 let last_will = match connect_flags & 0b100 {
151 0 if (connect_flags & 0b0011_1000) != 0 => {
152 return Err(Error::IncorrectPacketFormat);
153 }
154 0 => None,
155 _ => {
156 let will_topic = read_mqtt_string(&mut bytes)?;
157 let will_message = read_mqtt_bytes(&mut bytes)?;
158 let will_qos = qos((connect_flags & 0b11000) >> 3)?;
159 Some(LastWill {
160 topic: will_topic,
161 message: will_message,
162 qos: will_qos,
163 retain: (connect_flags & 0b0010_0000) != 0,
164 })
165 }
166 };
167
168 Ok(last_will)
169 }
170
171 fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
172 let mut connect_flags = 0;
173
174 connect_flags |= 0x04 | (self.qos as u8) << 3;
175 if self.retain {
176 connect_flags |= 0x20;
177 }
178
179 write_mqtt_string(buffer, &self.topic);
180 write_mqtt_bytes(buffer, &self.message);
181 Ok(connect_flags)
182 }
183}
184
/// Credentials carried at the end of a CONNECT packet's payload.
///
/// An empty field is treated as "not present" by the encoder and decoder in
/// this file.
#[derive(Debug, Clone, PartialEq)]
pub struct Login {
    /// User name; written when non-empty (connect flag `0x80`).
    username: String,
    /// Password; written when non-empty (connect flag `0x40`).
    password: String,
}
191
192impl Login {
193 pub fn new<S: Into<String>>(u: S, p: S) -> Login {
194 Login {
195 username: u.into(),
196 password: p.into(),
197 }
198 }
199
200 fn extract(connect_flags: u8, mut bytes: &mut Bytes) -> Result<Option<Login>, Error> {
201 let username = match connect_flags & 0b1000_0000 {
202 0 => String::new(),
203 _ => read_mqtt_string(&mut bytes)?,
204 };
205
206 let password = match connect_flags & 0b0100_0000 {
207 0 => String::new(),
208 _ => read_mqtt_string(&mut bytes)?,
209 };
210
211 if username.is_empty() && password.is_empty() {
212 Ok(None)
213 } else {
214 Ok(Some(Login { username, password }))
215 }
216 }
217
218 fn len(&self) -> usize {
219 let mut len = 0;
220
221 if !self.username.is_empty() {
222 len += 2 + self.username.len();
223 }
224
225 if !self.password.is_empty() {
226 len += 2 + self.password.len();
227 }
228
229 len
230 }
231
232 fn write(&self, buffer: &mut BytesMut) -> u8 {
233 let mut connect_flags = 0;
234 if !self.username.is_empty() {
235 connect_flags |= 0x80;
236 write_mqtt_string(buffer, &self.username);
237 }
238
239 if !self.password.is_empty() {
240 connect_flags |= 0x40;
241 write_mqtt_string(buffer, &self.password);
242 }
243
244 connect_flags
245 }
246}
247
248impl fmt::Debug for Connect {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 write!(
251 f,
252 "Protocol = {:?}, Keep alive = {:?}, Client id = {}, Clean session = {}",
253 self.protocol, self.keep_alive, self.client_id, self.clean_session,
254 )
255 }
256}
257
#[cfg(test)]
mod test {
    use crate::*;
    use alloc::borrow::ToOwned;
    use alloc::vec;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// A CONNECT frame followed by unrelated bytes decodes into the expected
    /// `Connect`.
    #[test]
    fn connect_stitching_works_correctly() {
        let mut stream = BytesMut::new();
        let packetstream = &[
            0x10, // packet type: CONNECT
            39,   // remaining length
            0x00, 0x04, b'M', b'Q', b'T', b'T', // protocol name
            0x04,        // protocol level
            0b1100_1110, // connect flags: user name, password, will QoS 1, will, clean session
            0x00, 0x0a, // keep alive = 10
            0x00, 0x04, b't', b'e', b's', b't', // client id
            0x00, 0x02, b'/', b'a', // will topic
            0x00, 0x07, b'o', b'f', b'f', b'l', b'i', b'n', b'e', // will message
            0x00, 0x04, b'r', b'u', b'm', b'q', // user name
            0x00, 0x02, b'm', b'q', // password
            0xDE, 0xAD, 0xBE, 0xEF, // trailing bytes that are not part of this packet
        ];

        stream.extend_from_slice(&packetstream[..]);
        let packet = mqtt_read(&mut stream, 100).unwrap();
        let packet = match packet {
            Packet::Connect(connect) => connect,
            packet => panic!("Invalid packet = {:?}", packet),
        };

        assert_eq!(
            packet,
            Connect {
                protocol: Protocol::MQTT(4),
                keep_alive: 10,
                client_id: "test".to_owned(),
                clean_session: true,
                last_will: Some(LastWill::new("/a", QoS::AtLeastOnce, "offline")),
                login: Some(Login::new("rumq", "mq"))
            }
        );
    }

    /// A CONNACK frame followed by unrelated bytes decodes into the expected
    /// `ConnAck`.
    #[test]
    fn connack_stitching_works_correctly() {
        let mut stream = BytesMut::new();
        let packetstream = &[
            0b0010_0000, // packet type: CONNACK
            0x02, // remaining length
            0x01, // session present
            0x00, // return code: accepted
            0xDE, 0xAD, 0xBE, 0xEF, // trailing bytes that are not part of this packet
        ];

        stream.extend_from_slice(&packetstream[..]);
        let packet = mqtt_read(&mut stream, 100).unwrap();
        let packet = match packet {
            Packet::ConnAck(packet) => packet,
            packet => panic!("Invalid packet = {:?}", packet),
        };

        assert_eq!(
            packet,
            ConnAck {
                session_present: true,
                code: ConnectReturnCode::Accepted
            }
        );
    }

    /// Encoding a fully populated `Connect` yields the expected bytes and
    /// reports the number of bytes written.
    #[test]
    fn write_connect_mqtt_packet_works() {
        let connect = Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", QoS::AtLeastOnce, "offline")),
            login: Some(Login::new("rust", "mq")),
        };

        let mut buf = BytesMut::new();
        let written = connect.write(&mut buf).unwrap();

        // The buffer started empty, so the reported size must match its length.
        assert_eq!(written, buf.len());

        assert_eq!(
            buf,
            vec![
                0x10, // packet type: CONNECT
                39,   // remaining length
                0x00, 0x04, b'M', b'Q', b'T', b'T', // protocol name
                0x04,        // protocol level
                0b1100_1110, // connect flags
                0x00, 0x0a, // keep alive = 10
                0x00, 0x04, b't', b'e', b's', b't', // client id
                0x00, 0x02, b'/', b'a', // will topic
                0x00, 0x07, b'o', b'f', b'f', b'l', b'i', b'n', b'e', // will message
                0x00, 0x04, b'r', b'u', b's', b't', // user name
                0x00, 0x02, b'm', b'q', // password
            ]
        );
    }
}