1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::prelude::{format, String, ToString, Vec};
5use crate::protocol::v5::properties::Properties;
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bebytes::BeBytes;
9use bytes::{Buf, BufMut};
10
/// Per-subscription options attached to a topic filter in a SUBSCRIBE packet.
///
/// The four fields are packed into a single options byte by
/// [`SubscriptionOptions::encode`] and unpacked again by
/// [`SubscriptionOptions::decode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubscriptionOptions {
    /// Requested quality of service; occupies the two lowest bits of the options byte.
    pub qos: QoS,
    /// No Local flag; encoded as bit `0x04` of the options byte.
    pub no_local: bool,
    /// Retain As Published flag; encoded as bit `0x08` of the options byte.
    pub retain_as_published: bool,
    /// Retain handling mode; encoded shifted left by 4 in the options byte.
    pub retain_handling: RetainHandling,
}
23
/// Retain handling mode carried inside a subscription's options byte.
///
/// The enum is `#[repr(u8)]` and the discriminants are the raw values written
/// to the wire, so a variant can be converted with `as u8` when encoding.
///
/// NOTE(review): the variant names suggest the MQTT v5 "Retain Handling"
/// subscription option (when retained messages are delivered for a new
/// subscription) — confirm the exact semantics against the specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RetainHandling {
    /// Wire value `0`.
    SendAtSubscribe = 0,
    /// Wire value `1`.
    SendAtSubscribeIfNew = 1,
    /// Wire value `2`.
    DoNotSend = 2,
}
35
/// Bit-field view of the one-byte subscription options, (de)serialised through
/// the `BeBytes` derive.
///
/// The field widths add up to 8 bits. The tests in this file assert that the
/// derived serialisation yields the same byte as [`SubscriptionOptions::encode`],
/// i.e. the fields are laid out in declaration order from the most significant
/// bits (`reserved_bits`) down to the least significant bits (`qos`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct SubscriptionOptionsBits {
    // Two reserved bits; `to_options` rejects any non-zero value.
    #[bits(2)]
    pub reserved_bits: u8,
    // Raw retain handling value; only 0, 1 and 2 map to a `RetainHandling` variant.
    #[bits(2)]
    pub retain_handling: u8,
    // Retain As Published flag as a single bit (0 or 1).
    #[bits(1)]
    pub retain_as_published: u8,
    // No Local flag as a single bit (0 or 1).
    #[bits(1)]
    pub no_local: u8,
    // Raw QoS value; only 0, 1 and 2 map to a `QoS` variant.
    #[bits(2)]
    pub qos: u8,
}
57
58impl SubscriptionOptionsBits {
59 #[must_use]
62 pub fn from_options(options: &SubscriptionOptions) -> Self {
63 Self {
64 reserved_bits: 0,
65 retain_handling: options.retain_handling as u8,
66 retain_as_published: u8::from(options.retain_as_published),
67 no_local: u8::from(options.no_local),
68 qos: options.qos as u8,
69 }
70 }
71
72 pub fn to_options(&self) -> Result<SubscriptionOptions> {
79 if self.reserved_bits != 0 {
81 return Err(MqttError::MalformedPacket(
82 "Reserved bits in subscription options must be 0".to_string(),
83 ));
84 }
85
86 let qos = match self.qos {
88 0 => QoS::AtMostOnce,
89 1 => QoS::AtLeastOnce,
90 2 => QoS::ExactlyOnce,
91 _ => {
92 return Err(MqttError::MalformedPacket(format!(
93 "Invalid QoS value in subscription options: {}",
94 self.qos
95 )))
96 }
97 };
98
99 let retain_handling = match self.retain_handling {
101 0 => RetainHandling::SendAtSubscribe,
102 1 => RetainHandling::SendAtSubscribeIfNew,
103 2 => RetainHandling::DoNotSend,
104 _ => {
105 return Err(MqttError::MalformedPacket(format!(
106 "Invalid retain handling value: {}",
107 self.retain_handling
108 )))
109 }
110 };
111
112 Ok(SubscriptionOptions {
113 qos,
114 no_local: self.no_local != 0,
115 retain_as_published: self.retain_as_published != 0,
116 retain_handling,
117 })
118 }
119}
120
121impl Default for SubscriptionOptions {
122 fn default() -> Self {
123 Self {
124 qos: QoS::AtMostOnce,
125 no_local: false,
126 retain_as_published: false,
127 retain_handling: RetainHandling::SendAtSubscribe,
128 }
129 }
130}
131
132impl SubscriptionOptions {
133 #[must_use]
135 pub fn new(qos: QoS) -> Self {
136 Self {
137 qos,
138 ..Default::default()
139 }
140 }
141
142 #[must_use]
144 pub fn with_qos(mut self, qos: QoS) -> Self {
145 self.qos = qos;
146 self
147 }
148
149 #[must_use]
152 pub fn encode(&self) -> u8 {
153 let mut byte = self.qos as u8;
154
155 if self.no_local {
156 byte |= 0x04;
157 }
158
159 if self.retain_as_published {
160 byte |= 0x08;
161 }
162
163 byte |= (self.retain_handling as u8) << 4;
164
165 byte
166 }
167
168 #[must_use]
171 pub fn encode_with_bebytes(&self) -> u8 {
172 let bits = SubscriptionOptionsBits::from_options(self);
173 bits.to_be_bytes()[0]
174 }
175
176 pub fn decode(byte: u8) -> Result<Self> {
183 let qos_val = byte & crate::constants::subscription::QOS_MASK;
184 let qos = match qos_val {
185 0 => QoS::AtMostOnce,
186 1 => QoS::AtLeastOnce,
187 2 => QoS::ExactlyOnce,
188 _ => {
189 return Err(MqttError::MalformedPacket(format!(
190 "Invalid QoS value in subscription options: {qos_val}"
191 )))
192 }
193 };
194
195 let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
196 let retain_as_published =
197 (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
198
199 let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
200 & crate::constants::subscription::QOS_MASK;
201 let retain_handling = match retain_handling_val {
202 0 => RetainHandling::SendAtSubscribe,
203 1 => RetainHandling::SendAtSubscribeIfNew,
204 2 => RetainHandling::DoNotSend,
205 _ => {
206 return Err(MqttError::MalformedPacket(format!(
207 "Invalid retain handling value: {retain_handling_val}"
208 )))
209 }
210 };
211
212 if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
214 return Err(MqttError::MalformedPacket(
215 "Reserved bits in subscription options must be 0".to_string(),
216 ));
217 }
218
219 Ok(Self {
220 qos,
221 no_local,
222 retain_as_published,
223 retain_handling,
224 })
225 }
226
227 pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
234 let (bits, _consumed) =
235 SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
236 MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
237 })?;
238
239 bits.to_options()
240 }
241}
242
/// One entry of a SUBSCRIBE payload: a topic filter string together with the
/// options requested for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopicFilter {
    /// The topic filter string, written to the wire with `encode_string`.
    pub filter: String,
    /// Options for this filter; encoded as the byte that follows the filter string.
    pub options: SubscriptionOptions,
}
251
252impl TopicFilter {
253 #[must_use]
255 pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
256 Self {
257 filter: filter.into(),
258 options: SubscriptionOptions::new(qos),
259 }
260 }
261
262 #[must_use]
264 pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
265 Self {
266 filter: filter.into(),
267 options,
268 }
269 }
270}
271
/// An MQTT SUBSCRIBE packet.
#[derive(Debug, Clone)]
pub struct SubscribePacket {
    /// Packet identifier, written as the first two bytes of the body.
    pub packet_id: u16,
    /// Topic filters to subscribe to; encoding fails if this is empty.
    pub filters: Vec<TopicFilter>,
    /// Properties; only encoded and decoded when `protocol_version` is 5.
    pub properties: Properties,
    /// Protocol version as a raw number: `new` sets 5, `new_v311` sets 4.
    pub protocol_version: u8,
}
284
285impl SubscribePacket {
286 #[must_use]
288 pub fn new(packet_id: u16) -> Self {
289 Self {
290 packet_id,
291 filters: Vec::new(),
292 properties: Properties::default(),
293 protocol_version: 5,
294 }
295 }
296
297 #[must_use]
299 pub fn new_v311(packet_id: u16) -> Self {
300 Self {
301 packet_id,
302 filters: Vec::new(),
303 properties: Properties::default(),
304 protocol_version: 4,
305 }
306 }
307
308 #[must_use]
310 pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
311 self.filters.push(TopicFilter::new(filter, qos));
312 self
313 }
314
315 #[must_use]
317 pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
318 self.filters.push(filter);
319 self
320 }
321
322 #[must_use]
324 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
325 self.properties.set_subscription_identifier(id);
326 self
327 }
328
329 #[must_use]
331 pub fn with_user_property(mut self, key: String, value: String) -> Self {
332 self.properties.add_user_property(key, value);
333 self
334 }
335}
336
337impl MqttPacket for SubscribePacket {
338 fn packet_type(&self) -> PacketType {
339 PacketType::Subscribe
340 }
341
342 fn flags(&self) -> u8 {
343 0x02 }
345
346 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
347 buf.put_u16(self.packet_id);
348
349 if self.protocol_version == 5 {
350 self.properties.encode(buf)?;
351 }
352
353 if self.filters.is_empty() {
354 return Err(MqttError::MalformedPacket(
355 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
356 ));
357 }
358
359 for filter in &self.filters {
360 encode_string(buf, &filter.filter)?;
361 if self.protocol_version == 5 {
362 buf.put_u8(filter.options.encode());
363 } else {
364 buf.put_u8(filter.options.qos as u8);
365 }
366 }
367
368 Ok(())
369 }
370
371 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
372 Self::decode_body_with_version(buf, fixed_header, 5)
373 }
374}
375
376impl SubscribePacket {
377 pub fn decode_body_with_version<B: Buf>(
383 buf: &mut B,
384 fixed_header: &FixedHeader,
385 protocol_version: u8,
386 ) -> Result<Self> {
387 ProtocolVersion::try_from(protocol_version)
388 .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
389
390 if fixed_header.flags != 0x02 {
391 return Err(MqttError::MalformedPacket(format!(
392 "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
393 fixed_header.flags
394 )));
395 }
396
397 if buf.remaining() < 2 {
398 return Err(MqttError::MalformedPacket(
399 "SUBSCRIBE missing packet identifier".to_string(),
400 ));
401 }
402 let packet_id = buf.get_u16();
403
404 let properties = if protocol_version == 5 {
405 Properties::decode(buf)?
406 } else {
407 Properties::default()
408 };
409
410 let mut filters = Vec::new();
411
412 if !buf.has_remaining() {
413 return Err(MqttError::MalformedPacket(
414 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
415 ));
416 }
417
418 while buf.has_remaining() {
419 let filter_str = decode_string(buf)?;
420
421 if !buf.has_remaining() {
422 return Err(MqttError::MalformedPacket(
423 "Missing subscription options for topic filter".to_string(),
424 ));
425 }
426
427 let options_byte = buf.get_u8();
428 let options = if protocol_version == 5 {
429 SubscriptionOptions::decode(options_byte)?
430 } else {
431 SubscriptionOptions {
432 qos: QoS::from(options_byte & 0x03),
433 ..Default::default()
434 }
435 };
436
437 filters.push(TopicFilter {
438 filter: filter_str,
439 options,
440 });
441 }
442
443 Ok(Self {
444 packet_id,
445 filters,
446 properties,
447 protocol_version,
448 })
449 }
450}
451
// Unit and property tests for subscription options and the SUBSCRIBE packet.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    // Tests that the manual bit-twiddling codec and the `BeBytes`-derived
    // codec for the options byte agree with each other.
    #[cfg(test)]
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        // Hand-picked options: both encoders must emit the same byte and both
        // decoders must recover the original options.
        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            let test_cases = vec![
                // All-default options.
                SubscriptionOptions::default(),
                // Every field set to a non-default value.
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                // Mixed flags with the highest valid QoS and retain handling.
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                // Decode each byte with the matching decoder and compare.
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        // The bit-field struct serialises to exactly one byte and round-trips.
        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            let decoded_options = decoded_bits.to_options().unwrap();
            assert_eq!(decoded_options, options);
        }

        // Any non-zero reserved value must be rejected by `to_options`.
        #[test]
        fn test_reserved_bits_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

            // Lower reserved bit set.
            bits.reserved_bits = 1;
            assert!(bits.to_options().is_err());

            // Upper reserved bit set.
            bits.reserved_bits = 2;
            assert!(bits.to_options().is_err());
        }

        // QoS 3 fits in the two-bit field but has no `QoS` variant.
        #[test]
        fn test_invalid_qos_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.qos = 3;
            assert!(bits.to_options().is_err());
        }

        // Retain handling 3 fits in the two-bit field but has no variant.
        #[test]
        fn test_invalid_retain_handling_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.retain_handling = 3;
            assert!(bits.to_options().is_err());
        }

        proptest! {
            // For every valid combination of fields, both codecs agree on the
            // encoded byte and on the decoded options.
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let qos_enum = match qos {
                    0 => QoS::AtMostOnce,
                    1 => QoS::AtLeastOnce,
                    2 => QoS::ExactlyOnce,
                    _ => unreachable!(),
                };

                let retain_handling_enum = match retain_handling {
                    0 => RetainHandling::SendAtSubscribe,
                    1 => RetainHandling::SendAtSubscribeIfNew,
                    2 => RetainHandling::DoNotSend,
                    _ => unreachable!(),
                };

                let options = SubscriptionOptions {
                    qos: qos_enum,
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_enum,
                };

                // Both encoders must produce the same byte.
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                // Both decoders must recover the original options.
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            // The raw bit-field struct round-trips through its byte form and
            // converts to options carrying the same raw values.
            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                // The typed options must carry the same raw values.
                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    // Pins the exact wire byte for a known set of options.
    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        // 0x1D = 0b0001_1101: retain handling 1 (<< 4), retain-as-published
        // (0x08), no-local (0x04), QoS 1.
        assert_eq!(encoded, 0x1D);
        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    // The builder stores filters in insertion order with the requested QoS.
    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    // A filter built with explicit options keeps those options.
    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    // Full encode → fixed-header decode → body decode round trip (version 5).
    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        // Only presence of the subscription identifier is checked, not its value.
        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    // Decoding must fail when the fixed-header flags are not 0x02.
    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        // Flags 0x00 instead of the required 0x02.
        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    // Encoding a packet without any topic filter must fail.
    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}