1use crate::encoding::{decode_binary, decode_string, encode_binary, encode_string};
2use crate::error::{MqttError, Result};
3use crate::flags::ConnectFlags;
4use crate::packet::{FixedHeader, MqttPacket, PacketType};
5use crate::prelude::{format, String, ToString, Vec};
6use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
7use crate::types::{ConnectOptions, WillMessage, WillProperties};
8use crate::QoS;
9use bytes::{Buf, BufMut, Bytes};
10
/// Protocol name carried at the start of every CONNECT variable header.
const PROTOCOL_NAME: &str = "MQTT";
/// Protocol level byte for MQTT 5.0.
const PROTOCOL_VERSION_V5: u8 = 5;
/// Protocol level byte for MQTT 3.1.1.
const PROTOCOL_VERSION_V311: u8 = 4;
14
/// An MQTT CONNECT packet, usable for protocol level 5 (MQTT 5.0) and
/// protocol level 4 (MQTT 3.1.1).
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes the raw
/// `password` bytes — take care when logging values of this type.
#[derive(Debug, Clone)]
pub struct ConnectPacket {
    /// Protocol level byte written right after the protocol name.
    pub protocol_version: u8,
    /// Value of the Clean Start connect flag (called Clean Session in MQTT 3.1.1).
    pub clean_start: bool,
    /// Keep-alive interval in whole seconds.
    pub keep_alive: u16,
    /// Client identifier; always the first field of the payload.
    pub client_id: String,
    /// Optional user name; when present the User Name flag is set.
    pub username: Option<String>,
    /// Optional password bytes; when present the Password flag is set.
    pub password: Option<Vec<u8>>,
    /// Optional will message; when present the Will flag is set.
    pub will: Option<WillMessage>,
    /// CONNECT properties (only meaningful for protocol level 5).
    pub properties: Properties,
    /// Will properties (only meaningful for protocol level 5 with a will message).
    pub will_properties: Properties,
}
37
38impl ConnectPacket {
39 #[must_use]
41 pub fn new(options: ConnectOptions) -> Self {
42 let properties = Self::build_connect_properties(&options.properties);
43 let will_properties = options
44 .will
45 .as_ref()
46 .map_or_else(Properties::default, |will| {
47 Self::build_will_properties(&will.properties)
48 });
49
50 Self {
51 protocol_version: PROTOCOL_VERSION_V5,
52 clean_start: options.clean_start,
53 keep_alive: Self::calculate_keep_alive(options.keep_alive),
54 client_id: options.client_id,
55 username: options.username,
56 password: options.password,
57 will: options.will,
58 properties,
59 will_properties,
60 }
61 }
62
63 fn build_connect_properties(props: &crate::types::ConnectProperties) -> Properties {
65 let mut properties = Properties::default();
66
67 if let Some(val) = props.session_expiry_interval {
68 let _ = properties.add(
69 PropertyId::SessionExpiryInterval,
70 PropertyValue::FourByteInteger(val),
71 );
72 }
73 if let Some(val) = props.receive_maximum {
74 let _ = properties.add(
75 PropertyId::ReceiveMaximum,
76 PropertyValue::TwoByteInteger(val),
77 );
78 }
79 if let Some(val) = props.maximum_packet_size {
80 let _ = properties.add(
81 PropertyId::MaximumPacketSize,
82 PropertyValue::FourByteInteger(val),
83 );
84 }
85 if let Some(val) = props.topic_alias_maximum {
86 let _ = properties.add(
87 PropertyId::TopicAliasMaximum,
88 PropertyValue::TwoByteInteger(val),
89 );
90 }
91 if let Some(val) = props.request_response_information {
92 let _ = properties.add(
93 PropertyId::RequestResponseInformation,
94 PropertyValue::Byte(u8::from(val)),
95 );
96 }
97 if let Some(val) = props.request_problem_information {
98 let _ = properties.add(
99 PropertyId::RequestProblemInformation,
100 PropertyValue::Byte(u8::from(val)),
101 );
102 }
103 if let Some(val) = &props.authentication_method {
104 let _ = properties.add(
105 PropertyId::AuthenticationMethod,
106 PropertyValue::Utf8String(val.clone()),
107 );
108 }
109 if let Some(val) = &props.authentication_data {
110 let _ = properties.add(
111 PropertyId::AuthenticationData,
112 PropertyValue::BinaryData(val.clone().into()),
113 );
114 }
115 for (key, value) in &props.user_properties {
116 let _ = properties.add(
117 PropertyId::UserProperty,
118 PropertyValue::Utf8StringPair(key.clone(), value.clone()),
119 );
120 }
121
122 properties
123 }
124
125 fn build_will_properties(will_props: &crate::types::WillProperties) -> Properties {
127 let mut properties = Properties::default();
128
129 if let Some(val) = will_props.will_delay_interval {
130 let _ = properties.add(
131 PropertyId::WillDelayInterval,
132 PropertyValue::FourByteInteger(val),
133 );
134 }
135 if let Some(val) = will_props.payload_format_indicator {
136 let _ = properties.add(
137 PropertyId::PayloadFormatIndicator,
138 PropertyValue::Byte(u8::from(val)),
139 );
140 }
141 if let Some(val) = will_props.message_expiry_interval {
142 let _ = properties.add(
143 PropertyId::MessageExpiryInterval,
144 PropertyValue::FourByteInteger(val),
145 );
146 }
147 if let Some(val) = &will_props.content_type {
148 let _ = properties.add(
149 PropertyId::ContentType,
150 PropertyValue::Utf8String(val.clone()),
151 );
152 }
153 if let Some(val) = &will_props.response_topic {
154 let _ = properties.add(
155 PropertyId::ResponseTopic,
156 PropertyValue::Utf8String(val.clone()),
157 );
158 }
159 if let Some(val) = &will_props.correlation_data {
160 let _ = properties.add(
161 PropertyId::CorrelationData,
162 PropertyValue::BinaryData(val.clone().into()),
163 );
164 }
165 for (key, value) in &will_props.user_properties {
166 let _ = properties.add(
167 PropertyId::UserProperty,
168 PropertyValue::Utf8StringPair(key.clone(), value.clone()),
169 );
170 }
171
172 properties
173 }
174
175 fn calculate_keep_alive(keep_alive: crate::time::Duration) -> u16 {
177 keep_alive
178 .as_secs()
179 .min(u64::from(u16::MAX))
180 .try_into()
181 .unwrap_or(u16::MAX)
182 }
183
184 #[must_use]
186 pub fn new_v311(options: ConnectOptions) -> Self {
187 Self {
188 protocol_version: PROTOCOL_VERSION_V311,
189 clean_start: options.clean_start,
190 keep_alive: Self::calculate_keep_alive(options.keep_alive),
191 client_id: options.client_id,
192 username: options.username,
193 password: options.password,
194 will: options.will,
195 properties: Properties::default(),
196 will_properties: Properties::default(),
197 }
198 }
199
200 fn connect_flags(&self) -> u8 {
202 let mut flags = 0u8;
203
204 if self.clean_start {
205 flags |= ConnectFlags::CleanStart as u8;
206 }
207
208 if let Some(ref will) = self.will {
209 flags |= ConnectFlags::WillFlag as u8;
210 flags = ConnectFlags::with_will_qos(flags, will.qos as u8);
211 if will.retain {
212 flags |= ConnectFlags::WillRetain as u8;
213 }
214 }
215
216 if self.username.is_some() {
217 flags |= ConnectFlags::UsernameFlag as u8;
218 }
219
220 if self.password.is_some() {
221 flags |= ConnectFlags::PasswordFlag as u8;
222 }
223
224 flags
225 }
226}
227
228impl MqttPacket for ConnectPacket {
229 fn packet_type(&self) -> PacketType {
230 PacketType::Connect
231 }
232
233 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
234 encode_string(buf, PROTOCOL_NAME)?;
236 buf.put_u8(self.protocol_version);
237 buf.put_u8(self.connect_flags());
238 buf.put_u16(self.keep_alive);
239
240 if self.protocol_version == PROTOCOL_VERSION_V5 {
242 self.properties.encode(buf)?;
243 }
244
245 encode_string(buf, &self.client_id)?;
247
248 if let Some(ref will) = self.will {
250 if self.protocol_version == PROTOCOL_VERSION_V5 {
251 self.will_properties.encode(buf)?;
252 }
253 encode_string(buf, &will.topic)?;
254 encode_binary(buf, &will.payload)?;
255 }
256
257 if let Some(ref username) = self.username {
259 encode_string(buf, username)?;
260 }
261
262 if let Some(ref password) = self.password {
264 encode_binary(buf, password)?;
265 }
266
267 Ok(())
268 }
269
270 fn decode_body<B: Buf>(buf: &mut B, _fixed_header: &FixedHeader) -> Result<Self> {
271 let protocol_version = Self::decode_protocol_header(buf)?;
273 let (flags, keep_alive) = Self::decode_connect_flags_and_keepalive(buf)?;
274
275 let properties = if protocol_version == PROTOCOL_VERSION_V5 {
277 Properties::decode(buf)?
278 } else {
279 Properties::default()
280 };
281
282 let client_id = decode_string(buf)?;
284 let (will, will_properties) = Self::decode_will(buf, &flags, protocol_version)?;
285 let (username, password) = Self::decode_credentials(buf, &flags)?;
286
287 Ok(Self {
288 protocol_version,
289 clean_start: flags.clean_start,
290 keep_alive,
291 client_id,
292 username,
293 password: password.map(|p| p.to_vec()),
294 will,
295 properties,
296 will_properties,
297 })
298 }
299}
300
/// The connect-flags byte of a received CONNECT packet, split into its fields.
struct DecodedConnectFlags {
    /// Clean Start bit (called Clean Session in MQTT 3.1.1).
    clean_start: bool,
    /// Will Flag: when set, will data follows the client identifier in the payload.
    will_flag: bool,
    /// Will QoS value as returned by `ConnectFlags::extract_will_qos`.
    will_qos: u8,
    /// Will Retain bit.
    will_retain: bool,
    /// User Name / Password presence bits.
    credentials: CredentialFlags,
}
309
/// Presence bits for the optional credential fields of a CONNECT payload.
struct CredentialFlags {
    /// User Name flag: a user name string is present in the payload.
    username_flag: bool,
    /// Password flag: password binary data is present in the payload.
    password_flag: bool,
}
314
315impl ConnectPacket {
316 fn decode_protocol_header<B: Buf>(buf: &mut B) -> Result<u8> {
318 let protocol_name = decode_string(buf)?;
320 if protocol_name != PROTOCOL_NAME {
321 return Err(MqttError::ProtocolError(format!(
322 "Invalid protocol name: {protocol_name}"
323 )));
324 }
325
326 if !buf.has_remaining() {
328 return Err(MqttError::MalformedPacket(
329 "Missing protocol version".to_string(),
330 ));
331 }
332 let protocol_version = buf.get_u8();
333
334 Ok(protocol_version)
335 }
336
337 fn decode_connect_flags_and_keepalive<B: Buf>(
339 buf: &mut B,
340 ) -> Result<(DecodedConnectFlags, u16)> {
341 if !buf.has_remaining() {
343 return Err(MqttError::MalformedPacket(
344 "Missing connect flags".to_string(),
345 ));
346 }
347 let flags = buf.get_u8();
348
349 let decomposed_flags = ConnectFlags::decompose(flags);
351
352 if decomposed_flags.contains(&ConnectFlags::Reserved) {
354 return Err(MqttError::MalformedPacket(
355 "Reserved flag bit must be 0".to_string(),
356 ));
357 }
358
359 let credentials = CredentialFlags {
360 username_flag: decomposed_flags.contains(&ConnectFlags::UsernameFlag),
361 password_flag: decomposed_flags.contains(&ConnectFlags::PasswordFlag),
362 };
363
364 let decoded_flags = DecodedConnectFlags {
365 clean_start: decomposed_flags.contains(&ConnectFlags::CleanStart),
366 will_flag: decomposed_flags.contains(&ConnectFlags::WillFlag),
367 will_qos: ConnectFlags::extract_will_qos(flags),
368 will_retain: decomposed_flags.contains(&ConnectFlags::WillRetain),
369 credentials,
370 };
371
372 if buf.remaining() < 2 {
374 return Err(MqttError::MalformedPacket("Missing keep alive".to_string()));
375 }
376 let keep_alive = buf.get_u16();
377
378 Ok((decoded_flags, keep_alive))
379 }
380
381 fn decode_will<B: Buf>(
383 buf: &mut B,
384 flags: &DecodedConnectFlags,
385 protocol_version: u8,
386 ) -> Result<(Option<WillMessage>, Properties)> {
387 if !flags.will_flag {
388 return Ok((None, Properties::default()));
389 }
390
391 let will_properties = if protocol_version == PROTOCOL_VERSION_V5 {
392 Properties::decode(buf)?
393 } else {
394 Properties::default()
395 };
396
397 let topic = decode_string(buf)?;
398 let payload = decode_binary(buf)?;
399
400 let qos = match flags.will_qos {
401 0 => QoS::AtMostOnce,
402 1 => QoS::AtLeastOnce,
403 2 => QoS::ExactlyOnce,
404 _ => return Err(MqttError::MalformedPacket("Invalid will QoS".to_string())),
405 };
406
407 let will_props = Self::properties_to_will_properties(&will_properties);
409
410 let will = WillMessage {
411 topic,
412 payload: payload.to_vec(),
413 qos,
414 retain: flags.will_retain,
415 properties: will_props,
416 };
417
418 Ok((Some(will), will_properties))
419 }
420
421 fn properties_to_will_properties(props: &Properties) -> WillProperties {
423 use crate::protocol::v5::properties::{PropertyId, PropertyValue};
424
425 let mut will_props = WillProperties::default();
426
427 if let Some(PropertyValue::FourByteInteger(delay)) =
429 props.get(PropertyId::WillDelayInterval)
430 {
431 will_props.will_delay_interval = Some(*delay);
432 }
433
434 if let Some(PropertyValue::Byte(indicator)) = props.get(PropertyId::PayloadFormatIndicator)
436 {
437 will_props.payload_format_indicator = Some(*indicator != 0);
438 }
439
440 if let Some(PropertyValue::FourByteInteger(expiry)) =
442 props.get(PropertyId::MessageExpiryInterval)
443 {
444 will_props.message_expiry_interval = Some(*expiry);
445 }
446
447 if let Some(PropertyValue::Utf8String(content_type)) = props.get(PropertyId::ContentType) {
449 will_props.content_type = Some(content_type.clone());
450 }
451
452 if let Some(PropertyValue::Utf8String(topic)) = props.get(PropertyId::ResponseTopic) {
454 will_props.response_topic = Some(topic.clone());
455 }
456
457 if let Some(PropertyValue::BinaryData(data)) = props.get(PropertyId::CorrelationData) {
459 will_props.correlation_data = Some(data.to_vec());
460 }
461
462 if let Some(values) = props.get_all(PropertyId::UserProperty) {
464 for value in values {
465 if let PropertyValue::Utf8StringPair(key, val) = value {
466 will_props.user_properties.push((key.clone(), val.clone()));
467 }
468 }
469 }
470
471 will_props
472 }
473
474 fn decode_credentials<B: Buf>(
476 buf: &mut B,
477 flags: &DecodedConnectFlags,
478 ) -> Result<(Option<String>, Option<Bytes>)> {
479 let username = if flags.credentials.username_flag {
480 Some(decode_string(buf)?)
481 } else {
482 None
483 };
484
485 let password = if flags.credentials.password_flag {
486 Some(decode_binary(buf)?)
487 } else {
488 None
489 };
490
491 if password.is_some() && username.is_none() {
493 return Err(MqttError::MalformedPacket(
494 "Password without username is not allowed".to_string(),
495 ));
496 }
497
498 Ok((username, password))
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::Duration;
    use bytes::BytesMut;

    // Default options: v5, clean start, 60 s keep alive, no credentials, no will.
    #[test]
    fn test_connect_packet_basic() {
        let options = ConnectOptions::new("test-client");
        let packet = ConnectPacket::new(options);

        assert_eq!(packet.protocol_version, PROTOCOL_VERSION_V5);
        assert!(packet.clean_start);
        assert_eq!(packet.keep_alive, 60);
        assert_eq!(packet.client_id, "test-client");
        assert!(packet.username.is_none());
        assert!(packet.password.is_none());
        assert!(packet.will.is_none());
    }

    // Credentials from the options are carried over unchanged.
    #[test]
    fn test_connect_packet_with_credentials() {
        let options = ConnectOptions::new("test-client").with_credentials("user", b"pass");
        let packet = ConnectPacket::new(options);

        assert_eq!(packet.username, Some("user".to_string()));
        assert_eq!(packet.password, Some(b"pass".to_vec()));
    }

    // A configured will keeps its topic, payload, QoS and retain setting.
    #[test]
    fn test_connect_packet_with_will() {
        let will = WillMessage::new("will/topic", b"will payload")
            .with_qos(QoS::AtLeastOnce)
            .with_retain(true);
        let options = ConnectOptions::new("test-client").with_will(will);
        let packet = ConnectPacket::new(options);

        assert!(packet.will.is_some());
        let will = packet.will.as_ref().unwrap();
        assert_eq!(will.topic, "will/topic");
        assert_eq!(will.payload, b"will payload");
        assert_eq!(will.qos, QoS::AtLeastOnce);
        assert!(will.retain);
    }

    #[test]
    fn test_connect_flags() {
        // Default options: only Clean Start (0x02) is set.
        let packet = ConnectPacket::new(ConnectOptions::new("test"));
        assert_eq!(packet.connect_flags(), 0x02);

        // Credentials without clean start: User Name (0x80) | Password (0x40).
        let options = ConnectOptions::new("test")
            .with_clean_start(false)
            .with_credentials("user", b"pass");
        let packet = ConnectPacket::new(options);
        assert_eq!(packet.connect_flags(), 0xC0);

        // Retained QoS 2 will with default clean start:
        // Clean Start (0x02) | Will Flag (0x04) | Will QoS 2 (0x10) | Will Retain (0x20).
        let will = WillMessage::new("topic", b"payload")
            .with_qos(QoS::ExactlyOnce)
            .with_retain(true);
        let options = ConnectOptions::new("test").with_will(will);
        let packet = ConnectPacket::new(options);
        assert_eq!(packet.connect_flags(), 0x36);
    }

    // Round trip of a v5 packet through encode + fixed-header decode + body decode.
    #[test]
    fn test_connect_encode_decode_v5() {
        let options = ConnectOptions::new("test-client-123")
            .with_keep_alive(Duration::from_secs(120))
            .with_credentials("testuser", b"testpass");
        let packet = ConnectPacket::new(options);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Connect);

        let decoded = ConnectPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V5);
        assert_eq!(decoded.client_id, "test-client-123");
        assert_eq!(decoded.keep_alive, 120);
        assert_eq!(decoded.username, Some("testuser".to_string()));
        assert_eq!(decoded.password, Some(b"testpass".to_vec()));
    }

    // Round trip of a 3.1.1 packet (no property sections on the wire).
    #[test]
    fn test_connect_encode_decode_v311() {
        let options = ConnectOptions::new("mqtt-311-client");
        let packet = ConnectPacket::new_v311(options);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = ConnectPacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.protocol_version, PROTOCOL_VERSION_V311);
        assert_eq!(decoded.client_id, "mqtt-311-client");
    }

    // Any protocol name other than "MQTT" is rejected.
    #[test]
    fn test_connect_invalid_protocol_name() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "INVALID").unwrap();
        buf.put_u8(5);

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    // NOTE(review): the buffer ends right after the protocol level byte, so decoding
    // fails on the missing connect flags whether or not the level itself is validated.
    #[test]
    fn test_connect_invalid_protocol_version() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "MQTT").unwrap();
        buf.put_u8(99);
        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    // v5 packet with only the Password flag (0x40) set is rejected by the decoder.
    #[test]
    fn test_connect_password_without_username() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "MQTT").unwrap();
        buf.put_u8(5); // protocol level
        buf.put_u8(0x40); // connect flags: Password only
        buf.put_u16(60); // keep alive
        buf.put_u8(0); // empty CONNECT properties
        encode_string(&mut buf, "client").unwrap();
        encode_binary(&mut buf, b"password").unwrap();

        let fixed_header = FixedHeader::new(PacketType::Connect, 0, 0);
        let result = ConnectPacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }
}