1use crate::encoding::{decode_binary, decode_string, encode_binary, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::ConnectFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
6use crate::types::{ConnectOptions, WillMessage, WillProperties};
7use crate::QoS;
8use bytes::{Buf, BufMut, Bytes};
9
/// Protocol name written at the start of every CONNECT variable header and
/// required to match exactly when decoding.
const PROTOCOL_NAME: &str = "MQTT";
/// Protocol level byte used for MQTT 5.0 packets (see `ConnectPacket::new`).
const PROTOCOL_VERSION_V5: u8 = 5;
/// Protocol level byte used for MQTT 3.1.1 packets (see `ConnectPacket::new_v311`).
const PROTOCOL_VERSION_V311: u8 = 4;
13
/// An MQTT CONNECT packet, encodable as either MQTT 5.0 or MQTT 3.1.1
/// depending on `protocol_version`.
#[derive(Debug, Clone)]
pub struct ConnectPacket {
    /// Protocol level byte written right after the protocol name
    /// (5 when built with `new`, 4 when built with `new_v311`).
    pub protocol_version: u8,
    /// Sets the Clean Start flag (called Clean Session in MQTT 3.1.1).
    pub clean_start: bool,
    /// Keep-alive interval in whole seconds.
    pub keep_alive: u16,
    /// Client identifier, always encoded as the first payload field.
    pub client_id: String,
    /// Optional user name; its presence sets the User Name flag.
    pub username: Option<String>,
    /// Optional password as raw bytes; its presence sets the Password flag.
    pub password: Option<Vec<u8>>,
    /// Optional will message; its QoS and retain settings are reflected in the
    /// connect flags, and its topic/payload follow the client id on the wire.
    pub will: Option<WillMessage>,
    /// CONNECT properties; only encoded/decoded when the protocol level is 5.
    pub properties: Properties,
    /// Will properties; only encoded/decoded when the protocol level is 5 and
    /// a will is present.
    pub will_properties: Properties,
}
36
37impl ConnectPacket {
38 #[must_use]
40 pub fn new(options: ConnectOptions) -> Self {
41 let properties = Self::build_connect_properties(&options.properties);
42 let will_properties = options
43 .will
44 .as_ref()
45 .map_or_else(Properties::default, |will| {
46 Self::build_will_properties(&will.properties)
47 });
48
49 Self {
50 protocol_version: PROTOCOL_VERSION_V5,
51 clean_start: options.clean_start,
52 keep_alive: Self::calculate_keep_alive(options.keep_alive),
53 client_id: options.client_id,
54 username: options.username,
55 password: options.password,
56 will: options.will,
57 properties,
58 will_properties,
59 }
60 }
61
62 fn build_connect_properties(props: &crate::types::ConnectProperties) -> Properties {
64 let mut properties = Properties::default();
65
66 if let Some(val) = props.session_expiry_interval {
67 let _ = properties.add(
68 PropertyId::SessionExpiryInterval,
69 PropertyValue::FourByteInteger(val),
70 );
71 }
72 if let Some(val) = props.receive_maximum {
73 let _ = properties.add(
74 PropertyId::ReceiveMaximum,
75 PropertyValue::TwoByteInteger(val),
76 );
77 }
78 if let Some(val) = props.maximum_packet_size {
79 let _ = properties.add(
80 PropertyId::MaximumPacketSize,
81 PropertyValue::FourByteInteger(val),
82 );
83 }
84 if let Some(val) = props.topic_alias_maximum {
85 let _ = properties.add(
86 PropertyId::TopicAliasMaximum,
87 PropertyValue::TwoByteInteger(val),
88 );
89 }
90 if let Some(val) = props.request_response_information {
91 let _ = properties.add(
92 PropertyId::RequestResponseInformation,
93 PropertyValue::Byte(u8::from(val)),
94 );
95 }
96 if let Some(val) = props.request_problem_information {
97 let _ = properties.add(
98 PropertyId::RequestProblemInformation,
99 PropertyValue::Byte(u8::from(val)),
100 );
101 }
102 if let Some(val) = &props.authentication_method {
103 let _ = properties.add(
104 PropertyId::AuthenticationMethod,
105 PropertyValue::Utf8String(val.clone()),
106 );
107 }
108 if let Some(val) = &props.authentication_data {
109 let _ = properties.add(
110 PropertyId::AuthenticationData,
111 PropertyValue::BinaryData(val.clone().into()),
112 );
113 }
114 for (key, value) in &props.user_properties {
115 let _ = properties.add(
116 PropertyId::UserProperty,
117 PropertyValue::Utf8StringPair(key.clone(), value.clone()),
118 );
119 }
120
121 properties
122 }
123
124 fn build_will_properties(will_props: &crate::types::WillProperties) -> Properties {
126 let mut properties = Properties::default();
127
128 if let Some(val) = will_props.will_delay_interval {
129 let _ = properties.add(
130 PropertyId::WillDelayInterval,
131 PropertyValue::FourByteInteger(val),
132 );
133 }
134 if let Some(val) = will_props.payload_format_indicator {
135 let _ = properties.add(
136 PropertyId::PayloadFormatIndicator,
137 PropertyValue::Byte(u8::from(val)),
138 );
139 }
140 if let Some(val) = will_props.message_expiry_interval {
141 let _ = properties.add(
142 PropertyId::MessageExpiryInterval,
143 PropertyValue::FourByteInteger(val),
144 );
145 }
146 if let Some(val) = &will_props.content_type {
147 let _ = properties.add(
148 PropertyId::ContentType,
149 PropertyValue::Utf8String(val.clone()),
150 );
151 }
152 if let Some(val) = &will_props.response_topic {
153 let _ = properties.add(
154 PropertyId::ResponseTopic,
155 PropertyValue::Utf8String(val.clone()),
156 );
157 }
158 if let Some(val) = &will_props.correlation_data {
159 let _ = properties.add(
160 PropertyId::CorrelationData,
161 PropertyValue::BinaryData(val.clone().into()),
162 );
163 }
164 for (key, value) in &will_props.user_properties {
165 let _ = properties.add(
166 PropertyId::UserProperty,
167 PropertyValue::Utf8StringPair(key.clone(), value.clone()),
168 );
169 }
170
171 properties
172 }
173
174 fn calculate_keep_alive(keep_alive: crate::time::Duration) -> u16 {
176 keep_alive
177 .as_secs()
178 .min(u64::from(u16::MAX))
179 .try_into()
180 .unwrap_or(u16::MAX)
181 }
182
183 #[must_use]
185 pub fn new_v311(options: ConnectOptions) -> Self {
186 Self {
187 protocol_version: PROTOCOL_VERSION_V311,
188 clean_start: options.clean_start,
189 keep_alive: Self::calculate_keep_alive(options.keep_alive),
190 client_id: options.client_id,
191 username: options.username,
192 password: options.password,
193 will: options.will,
194 properties: Properties::default(),
195 will_properties: Properties::default(),
196 }
197 }
198
199 fn connect_flags(&self) -> u8 {
201 let mut flags = 0u8;
202
203 if self.clean_start {
204 flags |= ConnectFlags::CleanStart as u8;
205 }
206
207 if let Some(ref will) = self.will {
208 flags |= ConnectFlags::WillFlag as u8;
209 flags = ConnectFlags::with_will_qos(flags, will.qos as u8);
210 if will.retain {
211 flags |= ConnectFlags::WillRetain as u8;
212 }
213 }
214
215 if self.username.is_some() {
216 flags |= ConnectFlags::UsernameFlag as u8;
217 }
218
219 if self.password.is_some() {
220 flags |= ConnectFlags::PasswordFlag as u8;
221 }
222
223 flags
224 }
225}
226
227impl MqttPacket for ConnectPacket {
228 fn packet_type(&self) -> PacketType {
229 PacketType::Connect
230 }
231
232 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
233 encode_string(buf, PROTOCOL_NAME)?;
235 buf.put_u8(self.protocol_version);
236 buf.put_u8(self.connect_flags());
237 buf.put_u16(self.keep_alive);
238
239 if self.protocol_version == PROTOCOL_VERSION_V5 {
241 self.properties.encode(buf)?;
242 }
243
244 encode_string(buf, &self.client_id)?;
246
247 if let Some(ref will) = self.will {
249 if self.protocol_version == PROTOCOL_VERSION_V5 {
250 self.will_properties.encode(buf)?;
251 }
252 encode_string(buf, &will.topic)?;
253 encode_binary(buf, &will.payload)?;
254 }
255
256 if let Some(ref username) = self.username {
258 encode_string(buf, username)?;
259 }
260
261 if let Some(ref password) = self.password {
263 encode_binary(buf, password)?;
264 }
265
266 Ok(())
267 }
268
269 fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
270 let protocol_version = Self::decode_protocol_header(buf)?;
272 let (flags, keep_alive) = Self::decode_connect_flags_and_keepalive(buf)?;
273
274 let properties = if protocol_version == PROTOCOL_VERSION_V5 {
276 Properties::decode(buf)?
277 } else {
278 Properties::default()
279 };
280
281 let client_id = decode_string(buf)?;
283 let (will, will_properties) = Self::decode_will(buf, &flags, protocol_version)?;
284 let (username, password) = Self::decode_credentials(buf, &flags)?;
285
286 Ok(Self {
287 protocol_version,
288 clean_start: flags.clean_start,
289 keep_alive,
290 client_id,
291 username,
292 password: password.map(|p| p.to_vec()),
293 will,
294 properties,
295 will_properties,
296 })
297 }
298}
299
/// The connect-flags byte of a decoded CONNECT packet, split into fields.
struct DecodedConnectFlags {
    /// Clean Start (MQTT 5.0) / Clean Session (MQTT 3.1.1) bit.
    clean_start: bool,
    /// Whether a will topic/payload follows the client id.
    will_flag: bool,
    /// Raw two-bit Will QoS value as extracted by `ConnectFlags::extract_will_qos`;
    /// not yet validated as 0..=2.
    will_qos: u8,
    /// Will Retain bit.
    will_retain: bool,
    /// Username/password presence bits.
    credentials: CredentialFlags,
}
308
/// Which credential fields the connect flags announce in the payload.
struct CredentialFlags {
    /// User Name flag: a username string follows in the payload.
    username_flag: bool,
    /// Password flag: password binary data follows in the payload.
    password_flag: bool,
}
313
314impl ConnectPacket {
315 fn decode_protocol_header<B: Buf>(buf: &mut B) -> Result<u8> {
317 let protocol_name = decode_string(buf)?;
319 if protocol_name != PROTOCOL_NAME {
320 return Err(MqttError::ProtocolError(format!(
321 "Invalid protocol name: {protocol_name}"
322 )));
323 }
324
325 if !buf.has_remaining() {
327 return Err(MqttError::MalformedPacket(
328 "Missing protocol version".to_string(),
329 ));
330 }
331 let protocol_version = buf.get_u8();
332
333 Ok(protocol_version)
334 }
335
336 fn decode_connect_flags_and_keepalive<B: Buf>(
338 buf: &mut B,
339 ) -> Result<(DecodedConnectFlags, u16)> {
340 if !buf.has_remaining() {
342 return Err(MqttError::MalformedPacket(
343 "Missing connect flags".to_string(),
344 ));
345 }
346 let flags = buf.get_u8();
347
348 let decomposed_flags = ConnectFlags::decompose(flags);
350
351 if decomposed_flags.contains(&ConnectFlags::Reserved) {
353 return Err(MqttError::MalformedPacket(
354 "Reserved flag bit must be 0".to_string(),
355 ));
356 }
357
358 let credentials = CredentialFlags {
359 username_flag: decomposed_flags.contains(&ConnectFlags::UsernameFlag),
360 password_flag: decomposed_flags.contains(&ConnectFlags::PasswordFlag),
361 };
362
363 let decoded_flags = DecodedConnectFlags {
364 clean_start: decomposed_flags.contains(&ConnectFlags::CleanStart),
365 will_flag: decomposed_flags.contains(&ConnectFlags::WillFlag),
366 will_qos: ConnectFlags::extract_will_qos(flags),
367 will_retain: decomposed_flags.contains(&ConnectFlags::WillRetain),
368 credentials,
369 };
370
371 if buf.remaining() < 2 {
373 return Err(MqttError::MalformedPacket("Missing keep alive".to_string()));
374 }
375 let keep_alive = buf.get_u16();
376
377 Ok((decoded_flags, keep_alive))
378 }
379
380 fn decode_will<B: Buf>(
382 buf: &mut B,
383 flags: &DecodedConnectFlags,
384 protocol_version: u8,
385 ) -> Result<(Option<WillMessage>, Properties)> {
386 if !flags.will_flag {
387 return Ok((None, Properties::default()));
388 }
389
390 let will_properties = if protocol_version == PROTOCOL_VERSION_V5 {
391 Properties::decode(buf)?
392 } else {
393 Properties::default()
394 };
395
396 let topic = decode_string(buf)?;
397 let payload = decode_binary(buf)?;
398
399 let qos = match flags.will_qos {
400 0 => QoS::AtMostOnce,
401 1 => QoS::AtLeastOnce,
402 2 => QoS::ExactlyOnce,
403 _ => return Err(MqttError::MalformedPacket("Invalid will QoS".to_string())),
404 };
405
406 let will_props = Self::properties_to_will_properties(&will_properties);
408
409 let will = WillMessage {
410 topic,
411 payload: payload.to_vec(),
412 qos,
413 retain: flags.will_retain,
414 properties: will_props,
415 };
416
417 Ok((Some(will), will_properties))
418 }
419
420 fn properties_to_will_properties(props: &Properties) -> WillProperties {
422 use crate::protocol::v5::properties::{PropertyId, PropertyValue};
423
424 let mut will_props = WillProperties::default();
425
426 if let Some(PropertyValue::FourByteInteger(delay)) =
428 props.get(PropertyId::WillDelayInterval)
429 {
430 will_props.will_delay_interval = Some(*delay);
431 }
432
433 if let Some(PropertyValue::Byte(indicator)) = props.get(PropertyId::PayloadFormatIndicator)
435 {
436 will_props.payload_format_indicator = Some(*indicator != 0);
437 }
438
439 if let Some(PropertyValue::FourByteInteger(expiry)) =
441 props.get(PropertyId::MessageExpiryInterval)
442 {
443 will_props.message_expiry_interval = Some(*expiry);
444 }
445
446 if let Some(PropertyValue::Utf8String(content_type)) = props.get(PropertyId::ContentType) {
448 will_props.content_type = Some(content_type.clone());
449 }
450
451 if let Some(PropertyValue::Utf8String(topic)) = props.get(PropertyId::ResponseTopic) {
453 will_props.response_topic = Some(topic.clone());
454 }
455
456 if let Some(PropertyValue::BinaryData(data)) = props.get(PropertyId::CorrelationData) {
458 will_props.correlation_data = Some(data.to_vec());
459 }
460
461 if let Some(values) = props.get_all(PropertyId::UserProperty) {
463 for value in values {
464 if let PropertyValue::Utf8StringPair(key, val) = value {
465 will_props.user_properties.push((key.clone(), val.clone()));
466 }
467 }
468 }
469
470 will_props
471 }
472
473 fn decode_credentials<B: Buf>(
475 buf: &mut B,
476 flags: &DecodedConnectFlags,
477 ) -> Result<(Option<String>, Option<Bytes>)> {
478 let username = if flags.credentials.username_flag {
479 Some(decode_string(buf)?)
480 } else {
481 None
482 };
483
484 let password = if flags.credentials.password_flag {
485 Some(decode_binary(buf)?)
486 } else {
487 None
488 };
489
490 if password.is_some() && username.is_none() {
492 return Err(MqttError::MalformedPacket(
493 "Password without username is not allowed".to_string(),
494 ));
495 }
496
497 Ok((username, password))
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::Duration;
    use bytes::BytesMut;

    #[test]
    fn test_connect_packet_basic() {
        let options = ConnectOptions::new("test-client");
        let packet = ConnectPacket::new(options);

        assert_eq!(packet.protocol_version, PROTOCOL_VERSION_V5);
        assert!(packet.clean_start);
        assert_eq!(packet.keep_alive, 60);
        assert_eq!(packet.client_id, "test-client");
        assert!(packet.username.is_none());
        assert!(packet.password.is_none());
        assert!(packet.will.is_none());
    }

    #[test]
    fn test_connect_packet_with_credentials() {
        let options = ConnectOptions::new("test-client").with_credentials("user", b"pass");
        let packet = ConnectPacket::new(options);

        assert_eq!(packet.username, Some("user".to_string()));
        assert_eq!(packet.password, Some(b"pass".to_vec()));
    }

    #[test]
    fn test_connect_packet_with_will() {
        let will = WillMessage::new("will/topic", b"will payload")
            .with_qos(QoS::AtLeastOnce)
            .with_retain(true);
        let options = ConnectOptions::new("test-client").with_will(will);
        let packet = ConnectPacket::new(options);

        assert!(packet.will.is_some());
        let will = packet.will.as_ref().unwrap();
        assert_eq!(will.topic, "will/topic");
        assert_eq!(will.payload, b"will payload");
        assert_eq!(will.qos, QoS::AtLeastOnce);
        assert!(will.retain);
    }

    #[test]
    fn test_connect_keep_alive_saturates() {
        // 100_000 s does not fit in the two-byte keep-alive field.
        let options =
            ConnectOptions::new("test-client").with_keep_alive(Duration::from_secs(100_000));
        let packet = ConnectPacket::new(options);

        assert_eq!(packet.keep_alive, u16::MAX);
    }

    #[test]
    fn test_connect_flags() {
        // Clean start only.
        let packet = ConnectPacket::new(ConnectOptions::new("test"));
        assert_eq!(packet.connect_flags(), 0x02);

        // Username (0x80) + password (0x40), clean start cleared.
        let options = ConnectOptions::new("test")
            .with_clean_start(false)
            .with_credentials("user", b"pass");
        let packet = ConnectPacket::new(options);
        assert_eq!(packet.connect_flags(), 0xC0);

        // Clean start (0x02) + will flag (0x04) + QoS 2 (0x10) + will retain (0x20).
        let will = WillMessage::new("topic", b"payload")
            .with_qos(QoS::ExactlyOnce)
            .with_retain(true);
        let options = ConnectOptions::new("test").with_will(will);
        let packet = ConnectPacket::new(options);
        assert_eq!(packet.connect_flags(), 0x36);
    }

    #[test]
    fn test_connect_encode_decode_v5() {
        let options = ConnectOptions::new("test-client-123")
            .with_keep_alive(Duration::from_secs(120))
            .with_credentials("testuser", b"testpass");
        let packet = ConnectPacket::new(options);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Connect);

        let decoded = ConnectPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V5);
        assert_eq!(decoded.client_id, "test-client-123");
        assert_eq!(decoded.keep_alive, 120);
        assert_eq!(decoded.username, Some("testuser".to_string()));
        assert_eq!(decoded.password, Some(b"testpass".to_vec()));
    }

    #[test]
    fn test_connect_encode_decode_with_will_v5() {
        let will = WillMessage::new("will/topic", b"bye")
            .with_qos(QoS::AtLeastOnce)
            .with_retain(true);
        let options = ConnectOptions::new("will-client").with_will(will);
        let packet = ConnectPacket::new(options);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = ConnectPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.client_id, "will-client");
        let will = decoded.will.expect("will should survive a round trip");
        assert_eq!(will.topic, "will/topic");
        assert_eq!(will.payload, b"bye");
        assert_eq!(will.qos, QoS::AtLeastOnce);
        assert!(will.retain);
    }

    #[test]
    fn test_connect_encode_decode_v311() {
        let options = ConnectOptions::new("mqtt-311-client");
        let packet = ConnectPacket::new_v311(options);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = ConnectPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V311);
        assert_eq!(decoded.client_id, "mqtt-311-client");
    }

    #[test]
    fn test_connect_invalid_protocol_name() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "INVALID").unwrap();
        buf.put_u8(5);

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_connect_invalid_protocol_version() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "MQTT").unwrap();
        buf.put_u8(99);

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_connect_reserved_flag_rejected() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "MQTT").unwrap();
        buf.put_u8(5);
        // Clean start (0x02) plus the reserved bit (0x01).
        buf.put_u8(0x03);
        buf.put_u16(60);
        // Empty property section.
        buf.put_u8(0);
        encode_string(&mut buf, "client").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_connect_password_without_username() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "MQTT").unwrap();
        buf.put_u8(5);
        // Password flag only.
        buf.put_u8(0x40);
        buf.put_u16(60);
        // Empty property section.
        buf.put_u8(0);
        encode_string(&mut buf, "client").unwrap();
        encode_binary(&mut buf, b"password").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }
}