1use crate::{mqtt_string_eq, mqtt_string_new, MqttString};
2
3use super::{
4 len_len, length, property, qos, read_mqtt_bytes, read_mqtt_string, read_u16, read_u32, read_u8,
5 write_mqtt_bytes, write_mqtt_string, write_remaining_length, BufMut, BytesMut, Debug, Error,
6 FixedHeader, PropertyType, QoS,
7};
8use bytes::{Buf, Bytes};
9
10#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct Connect {
13 pub keep_alive: u16,
15 pub client_id: MqttString,
17 pub clean_start: bool,
19 pub properties: Option<ConnectProperties>,
20}
21
22impl Connect {
23 #[allow(clippy::type_complexity)]
24 pub fn read(
25 fixed_header: FixedHeader,
26 mut bytes: Bytes,
27 ) -> Result<(Connect, Option<LastWill>, Option<Login>), Error> {
28 let variable_header_index = fixed_header.fixed_header_len;
29 bytes.advance(variable_header_index);
30
31 let protocol_name = read_mqtt_string(&mut bytes)?;
33 let protocol_level = read_u8(&mut bytes)?;
34 if !mqtt_string_eq(&protocol_name, "MQTT") {
35 return Err(Error::InvalidProtocol);
36 }
37
38 if protocol_level != 5 {
39 return Err(Error::InvalidProtocolLevel(protocol_level));
40 }
41
42 let connect_flags = read_u8(&mut bytes)?;
43 let clean_start = (connect_flags & 0b10) != 0;
44 let keep_alive = read_u16(&mut bytes)?;
45
46 let properties = ConnectProperties::read(&mut bytes)?;
47
48 let client_id = read_mqtt_string(&mut bytes)?;
49 let will = LastWill::read(connect_flags, &mut bytes)?;
50 let login = Login::read(connect_flags, &mut bytes)?;
51
52 let connect = Connect {
53 keep_alive,
54 client_id,
55 clean_start,
56 properties,
57 };
58
59 Ok((connect, will, login))
60 }
61
62 fn len(&self, will: &Option<LastWill>, l: &Option<Login>) -> usize {
63 let mut len = 2 + "MQTT".len() + 1 + 1 + 2; if let Some(p) = &self.properties {
69 let properties_len = p.len();
70 let properties_len_len = len_len(properties_len);
71 len += properties_len_len + properties_len;
72 } else {
73 len += 1;
75 }
76
77 len += 2 + self.client_id.len();
78
79 if let Some(w) = will {
81 len += w.len();
82 }
83
84 if let Some(l) = l {
86 len += l.len();
87 }
88
89 len
90 }
91
92 pub fn write(
93 &self,
94 will: &Option<LastWill>,
95 l: &Option<Login>,
96 buffer: &mut BytesMut,
97 ) -> Result<usize, Error> {
98 let len = self.len(will, l);
99
100 buffer.put_u8(0b0001_0000);
101 let count = write_remaining_length(buffer, len)?;
102 write_mqtt_string(buffer, &mqtt_string_new("MQTT"))?;
103
104 buffer.put_u8(0x05);
105 let flags_index = 1 + count + 2 + 4 + 1;
106
107 let mut connect_flags = 0;
108 if self.clean_start {
109 connect_flags |= 0x02;
110 }
111
112 buffer.put_u8(connect_flags);
113 buffer.put_u16(self.keep_alive);
114
115 match &self.properties {
116 Some(p) => p.write(buffer)?,
117 None => {
118 write_remaining_length(buffer, 0)?;
119 }
120 };
121
122 write_mqtt_string(buffer, &self.client_id)?;
123
124 if let Some(w) = will {
125 connect_flags |= w.write(buffer)?;
126 }
127
128 if let Some(l) = l {
129 connect_flags |= l.write(buffer)?;
130 }
131
132 buffer[flags_index] = connect_flags;
134 Ok(1 + count + len)
135 }
136}
137
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct ConnectProperties {
140 pub session_expiry_interval: Option<u32>,
142 pub receive_maximum: Option<u16>,
144 pub max_packet_size: Option<u32>,
146 pub topic_alias_max: Option<u16>,
148 pub request_response_info: Option<u8>,
149 pub request_problem_info: Option<u8>,
150 pub user_properties: Vec<(MqttString, MqttString)>,
152 pub authentication_method: Option<MqttString>,
154 pub authentication_data: Option<Bytes>,
156}
157
158impl ConnectProperties {
159 #[must_use]
160 pub fn new() -> ConnectProperties {
161 ConnectProperties {
162 session_expiry_interval: None,
163 receive_maximum: None,
164 max_packet_size: None,
165 topic_alias_max: None,
166 request_response_info: None,
167 request_problem_info: None,
168 user_properties: Vec::new(),
169 authentication_method: None,
170 authentication_data: None,
171 }
172 }
173
174 pub fn read(bytes: &mut Bytes) -> Result<Option<ConnectProperties>, Error> {
175 let mut session_expiry_interval = None;
176 let mut receive_maximum = None;
177 let mut max_packet_size = None;
178 let mut topic_alias_max = None;
179 let mut request_response_info = None;
180 let mut request_problem_info = None;
181 let mut user_properties = Vec::new();
182 let mut authentication_method = None;
183 let mut authentication_data = None;
184
185 let (properties_len_len, properties_len) = length(bytes.iter())?;
186 bytes.advance(properties_len_len);
187 if properties_len == 0 {
188 return Ok(None);
189 }
190
191 let mut cursor = 0;
192 while cursor < properties_len {
194 let prop = read_u8(bytes)?;
195 cursor += 1;
196 match property(prop)? {
197 PropertyType::SessionExpiryInterval => {
198 session_expiry_interval = Some(read_u32(bytes)?);
199 cursor += 4;
200 }
201 PropertyType::ReceiveMaximum => {
202 receive_maximum = Some(read_u16(bytes)?);
203 cursor += 2;
204 }
205 PropertyType::MaximumPacketSize => {
206 max_packet_size = Some(read_u32(bytes)?);
207 cursor += 4;
208 }
209 PropertyType::TopicAliasMaximum => {
210 topic_alias_max = Some(read_u16(bytes)?);
211 cursor += 2;
212 }
213 PropertyType::RequestResponseInformation => {
214 request_response_info = Some(read_u8(bytes)?);
215 cursor += 1;
216 }
217 PropertyType::RequestProblemInformation => {
218 request_problem_info = Some(read_u8(bytes)?);
219 cursor += 1;
220 }
221 PropertyType::UserProperty => {
222 let key = read_mqtt_string(bytes)?;
223 let value = read_mqtt_string(bytes)?;
224 cursor += 2 + key.len() + 2 + value.len();
225 user_properties.push((key, value));
226 }
227 PropertyType::AuthenticationMethod => {
228 let method = read_mqtt_string(bytes)?;
229 cursor += 2 + method.len();
230 authentication_method = Some(method);
231 }
232 PropertyType::AuthenticationData => {
233 let data = read_mqtt_bytes(bytes)?;
234 cursor += 2 + data.len();
235 authentication_data = Some(data);
236 }
237 _ => return Err(Error::InvalidPropertyType(prop)),
238 }
239 }
240
241 Ok(Some(ConnectProperties {
242 session_expiry_interval,
243 receive_maximum,
244 max_packet_size,
245 topic_alias_max,
246 request_response_info,
247 request_problem_info,
248 user_properties,
249 authentication_method,
250 authentication_data,
251 }))
252 }
253
254 fn len(&self) -> usize {
255 let mut len = 0;
256
257 if self.session_expiry_interval.is_some() {
258 len += 1 + 4;
259 }
260
261 if self.receive_maximum.is_some() {
262 len += 1 + 2;
263 }
264
265 if self.max_packet_size.is_some() {
266 len += 1 + 4;
267 }
268
269 if self.topic_alias_max.is_some() {
270 len += 1 + 2;
271 }
272
273 if self.request_response_info.is_some() {
274 len += 1 + 1;
275 }
276
277 if self.request_problem_info.is_some() {
278 len += 1 + 1;
279 }
280
281 for (key, value) in &self.user_properties {
282 len += 1 + 2 + key.len() + 2 + value.len();
283 }
284
285 if let Some(authentication_method) = &self.authentication_method {
286 len += 1 + 2 + authentication_method.len();
287 }
288
289 if let Some(authentication_data) = &self.authentication_data {
290 len += 1 + 2 + authentication_data.len();
291 }
292
293 len
294 }
295
296 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
297 let len = self.len();
298 write_remaining_length(buffer, len)?;
299
300 if let Some(session_expiry_interval) = self.session_expiry_interval {
301 buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
302 buffer.put_u32(session_expiry_interval);
303 }
304
305 if let Some(receive_maximum) = self.receive_maximum {
306 buffer.put_u8(PropertyType::ReceiveMaximum as u8);
307 buffer.put_u16(receive_maximum);
308 }
309
310 if let Some(max_packet_size) = self.max_packet_size {
311 buffer.put_u8(PropertyType::MaximumPacketSize as u8);
312 buffer.put_u32(max_packet_size);
313 }
314
315 if let Some(topic_alias_max) = self.topic_alias_max {
316 buffer.put_u8(PropertyType::TopicAliasMaximum as u8);
317 buffer.put_u16(topic_alias_max);
318 }
319
320 if let Some(request_response_info) = self.request_response_info {
321 buffer.put_u8(PropertyType::RequestResponseInformation as u8);
322 buffer.put_u8(request_response_info);
323 }
324
325 if let Some(request_problem_info) = self.request_problem_info {
326 buffer.put_u8(PropertyType::RequestProblemInformation as u8);
327 buffer.put_u8(request_problem_info);
328 }
329
330 for (key, value) in &self.user_properties {
331 buffer.put_u8(PropertyType::UserProperty as u8);
332 write_mqtt_string(buffer, key)?;
333 write_mqtt_string(buffer, value)?;
334 }
335
336 if let Some(authentication_method) = &self.authentication_method {
337 buffer.put_u8(PropertyType::AuthenticationMethod as u8);
338 write_mqtt_string(buffer, authentication_method)?;
339 }
340
341 if let Some(authentication_data) = &self.authentication_data {
342 buffer.put_u8(PropertyType::AuthenticationData as u8);
343 write_mqtt_bytes(buffer, authentication_data)?;
344 }
345
346 Ok(())
347 }
348}
349
350impl Default for ConnectProperties {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
356#[derive(Debug, Clone, PartialEq, Eq)]
358pub struct LastWill {
359 pub topic: Bytes,
360 pub message: Bytes,
361 pub qos: QoS,
362 pub retain: bool,
363 pub properties: Option<LastWillProperties>,
364}
365
366impl LastWill {
367 fn len(&self) -> usize {
368 let mut len = 0;
369
370 if let Some(p) = &self.properties {
371 let properties_len = p.len();
372 let properties_len_len = len_len(properties_len);
373 len += properties_len_len + properties_len;
374 } else {
375 len += 1;
377 }
378
379 len += 2 + self.topic.len() + 2 + self.message.len();
380 len
381 }
382
383 pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
384 let o = match connect_flags & 0b100 {
385 0 if (connect_flags & 0b0011_1000) != 0 => {
386 return Err(Error::IncorrectPacketFormat);
387 }
388 0 => None,
389 _ => {
390 let properties = LastWillProperties::read(bytes)?;
392
393 let will_topic = read_mqtt_bytes(bytes)?;
394 let will_message = read_mqtt_bytes(bytes)?;
395 let qos_num = (connect_flags & 0b11000) >> 3;
396 let will_qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
397 Some(LastWill {
398 topic: will_topic,
399 message: will_message,
400 qos: will_qos,
401 retain: (connect_flags & 0b0010_0000) != 0,
402 properties,
403 })
404 }
405 };
406
407 Ok(o)
408 }
409
410 pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
411 let mut connect_flags = 0;
412
413 connect_flags |= 0x04 | (self.qos as u8) << 3;
414 if self.retain {
415 connect_flags |= 0x20;
416 }
417
418 if let Some(p) = &self.properties {
419 p.write(buffer)?;
420 } else {
421 write_remaining_length(buffer, 0)?;
422 }
423
424 write_mqtt_bytes(buffer, &self.topic)?;
425 write_mqtt_bytes(buffer, &self.message)?;
426 Ok(connect_flags)
427 }
428}
429
430#[derive(Debug, Clone, PartialEq, Eq)]
431pub struct LastWillProperties {
432 pub delay_interval: Option<u32>,
433 pub payload_format_indicator: Option<u8>,
434 pub message_expiry_interval: Option<u32>,
435 pub content_type: Option<MqttString>,
436 pub response_topic: Option<MqttString>,
437 pub correlation_data: Option<Bytes>,
438 pub user_properties: Vec<(MqttString, MqttString)>,
439}
440
441impl LastWillProperties {
442 fn len(&self) -> usize {
443 let mut len = 0;
444
445 if self.delay_interval.is_some() {
446 len += 1 + 4;
447 }
448
449 if self.payload_format_indicator.is_some() {
450 len += 1 + 1;
451 }
452
453 if self.message_expiry_interval.is_some() {
454 len += 1 + 4;
455 }
456
457 if let Some(typ) = &self.content_type {
458 len += 1 + 2 + typ.len();
459 }
460
461 if let Some(topic) = &self.response_topic {
462 len += 1 + 2 + topic.len();
463 }
464
465 if let Some(data) = &self.correlation_data {
466 len += 1 + 2 + data.len();
467 }
468
469 for (key, value) in &self.user_properties {
470 len += 1 + 2 + key.len() + 2 + value.len();
471 }
472
473 len
474 }
475
476 pub fn read(bytes: &mut Bytes) -> Result<Option<LastWillProperties>, Error> {
477 let mut delay_interval = None;
478 let mut payload_format_indicator = None;
479 let mut message_expiry_interval = None;
480 let mut content_type = None;
481 let mut response_topic = None;
482 let mut correlation_data = None;
483 let mut user_properties = Vec::new();
484
485 let (properties_len_len, properties_len) = length(bytes.iter())?;
486 bytes.advance(properties_len_len);
487 if properties_len == 0 {
488 return Ok(None);
489 }
490
491 let mut cursor = 0;
492 while cursor < properties_len {
494 let prop = read_u8(bytes)?;
495 cursor += 1;
496
497 match property(prop)? {
498 PropertyType::WillDelayInterval => {
499 delay_interval = Some(read_u32(bytes)?);
500 cursor += 4;
501 }
502 PropertyType::PayloadFormatIndicator => {
503 payload_format_indicator = Some(read_u8(bytes)?);
504 cursor += 1;
505 }
506 PropertyType::MessageExpiryInterval => {
507 message_expiry_interval = Some(read_u32(bytes)?);
508 cursor += 4;
509 }
510 PropertyType::ContentType => {
511 let typ = read_mqtt_string(bytes)?;
512 cursor += 2 + typ.len();
513 content_type = Some(typ);
514 }
515 PropertyType::ResponseTopic => {
516 let topic = read_mqtt_string(bytes)?;
517 cursor += 2 + topic.len();
518 response_topic = Some(topic);
519 }
520 PropertyType::CorrelationData => {
521 let data = read_mqtt_bytes(bytes)?;
522 cursor += 2 + data.len();
523 correlation_data = Some(data);
524 }
525 PropertyType::UserProperty => {
526 let key = read_mqtt_string(bytes)?;
527 let value = read_mqtt_string(bytes)?;
528 cursor += 2 + key.len() + 2 + value.len();
529 user_properties.push((key, value));
530 }
531 _ => return Err(Error::InvalidPropertyType(prop)),
532 }
533 }
534
535 Ok(Some(LastWillProperties {
536 delay_interval,
537 payload_format_indicator,
538 message_expiry_interval,
539 content_type,
540 response_topic,
541 correlation_data,
542 user_properties,
543 }))
544 }
545
546 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
547 let len = self.len();
548 write_remaining_length(buffer, len)?;
549
550 if let Some(delay_interval) = self.delay_interval {
551 buffer.put_u8(PropertyType::WillDelayInterval as u8);
552 buffer.put_u32(delay_interval);
553 }
554
555 if let Some(payload_format_indicator) = self.payload_format_indicator {
556 buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
557 buffer.put_u8(payload_format_indicator);
558 }
559
560 if let Some(message_expiry_interval) = self.message_expiry_interval {
561 buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
562 buffer.put_u32(message_expiry_interval);
563 }
564
565 if let Some(typ) = &self.content_type {
566 buffer.put_u8(PropertyType::ContentType as u8);
567 write_mqtt_string(buffer, typ)?;
568 }
569
570 if let Some(topic) = &self.response_topic {
571 buffer.put_u8(PropertyType::ResponseTopic as u8);
572 write_mqtt_string(buffer, topic)?;
573 }
574
575 if let Some(data) = &self.correlation_data {
576 buffer.put_u8(PropertyType::CorrelationData as u8);
577 write_mqtt_bytes(buffer, data)?;
578 }
579
580 for (key, value) in &self.user_properties {
581 buffer.put_u8(PropertyType::UserProperty as u8);
582 write_mqtt_string(buffer, key)?;
583 write_mqtt_string(buffer, value)?;
584 }
585
586 Ok(())
587 }
588}
589#[derive(Debug, Clone, PartialEq, Eq)]
590pub struct Login {
591 pub username: MqttString,
592 pub password: MqttString,
593}
594
595impl Login {
596 pub fn new<U: Into<MqttString>, P: Into<MqttString>>(u: U, p: P) -> Login {
597 Login {
598 username: u.into(),
599 password: p.into(),
600 }
601 }
602
603 pub fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
604 let username = match connect_flags & 0b1000_0000 {
605 0 => MqttString::default(),
606 _ => read_mqtt_string(bytes)?,
607 };
608
609 let password = match connect_flags & 0b0100_0000 {
610 0 => MqttString::default(),
611 _ => read_mqtt_string(bytes)?,
612 };
613
614 if username.is_empty() && password.is_empty() {
615 Ok(None)
616 } else {
617 Ok(Some(Login { username, password }))
618 }
619 }
620
621 fn len(&self) -> usize {
622 let mut len = 0;
623
624 if !self.username.is_empty() {
625 len += 2 + self.username.len();
626 }
627
628 if !self.password.is_empty() {
629 len += 2 + self.password.len();
630 }
631
632 len
633 }
634
635 pub fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
636 let mut connect_flags = 0;
637 if !self.username.is_empty() {
638 connect_flags |= 0x80;
639 write_mqtt_string(buffer, &self.username)?;
640 }
641
642 if !self.password.is_empty() {
643 connect_flags |= 0x40;
644 write_mqtt_string(buffer, &self.password)?;
645 }
646
647 Ok(connect_flags)
648 }
649}
650
#[cfg(test)]
mod test {
    use crate::test::read_write_packets;
    use crate::Packet;

    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// The size reported by `Connect::write` must equal the bytes it produced.
    #[test]
    fn length_calculation() {
        let connect_pkt = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: Some(ConnectProperties {
                user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
                ..ConnectProperties::new()
            }),
        };

        let mut buf = BytesMut::new();
        let reported = connect_pkt.write(&None, &None, &mut buf).unwrap();
        let produced = buf.len();

        assert_eq!(reported, produced);
    }

    /// Every packet from the provider must survive a write/read round trip.
    #[test]
    fn test_write_read() {
        read_write_packets(write_read_provider());
    }

    /// A bare CONNECT, plus one exercising every property, a will and a login.
    fn write_read_provider() -> Vec<Packet> {
        let minimal = Connect {
            keep_alive: 5,
            client_id: "client".into(),
            clean_start: true,
            properties: None,
        };

        let full = Connect {
            properties: Some(ConnectProperties {
                session_expiry_interval: Some(5),
                receive_maximum: Some(5),
                max_packet_size: Some(5),
                topic_alias_max: Some(5),
                request_response_info: Some(5),
                request_problem_info: Some(5),
                user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
                authentication_method: Some("method".into()),
                authentication_data: Some(Bytes::from("data")),
            }),
            ..minimal.clone()
        };

        let will = LastWill {
            topic: Bytes::from("topic"),
            message: Bytes::from("message"),
            qos: QoS::AtLeastOnce,
            retain: true,
            properties: Some(LastWillProperties {
                delay_interval: Some(5),
                payload_format_indicator: Some(5),
                message_expiry_interval: Some(5),
                content_type: Some("type".into()),
                response_topic: Some("topic".into()),
                correlation_data: Some(Bytes::from("data")),
                user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            }),
        };

        let login = Login {
            username: "username".into(),
            password: "password".into(),
        };

        vec![
            Packet::Connect(minimal, None, None),
            Packet::Connect(full, Some(will), Some(login)),
        ]
    }
}