1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::protocol::v5::properties::Properties;
5use crate::types::ProtocolVersion;
6use crate::QoS;
7use bebytes::BeBytes;
8use bytes::{Buf, BufMut};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub struct SubscriptionOptions {
13 pub qos: QoS,
15 pub no_local: bool,
17 pub retain_as_published: bool,
19 pub retain_handling: RetainHandling,
21}
22
/// Value of the 2-bit Retain Handling field in the v5 Subscription Options byte.
///
/// The discriminants are the on-the-wire field values; `0b11` is not a valid
/// variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RetainHandling {
    /// Field value `0b00`.
    SendAtSubscribe = 0b00,
    /// Field value `0b01`.
    SendAtSubscribeIfNew = 0b01,
    /// Field value `0b10`.
    DoNotSend = 0b10,
}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
39pub struct SubscriptionOptionsBits {
40 #[bits(2)]
42 pub reserved_bits: u8,
43 #[bits(2)]
45 pub retain_handling: u8,
46 #[bits(1)]
48 pub retain_as_published: u8,
49 #[bits(1)]
51 pub no_local: u8,
52 #[bits(2)]
54 pub qos: u8,
55}
56
57impl SubscriptionOptionsBits {
58 #[must_use]
61 pub fn from_options(options: &SubscriptionOptions) -> Self {
62 Self {
63 reserved_bits: 0,
64 retain_handling: options.retain_handling as u8,
65 retain_as_published: u8::from(options.retain_as_published),
66 no_local: u8::from(options.no_local),
67 qos: options.qos as u8,
68 }
69 }
70
71 pub fn to_options(&self) -> Result<SubscriptionOptions> {
78 if self.reserved_bits != 0 {
80 return Err(MqttError::MalformedPacket(
81 "Reserved bits in subscription options must be 0".to_string(),
82 ));
83 }
84
85 let qos = match self.qos {
87 0 => QoS::AtMostOnce,
88 1 => QoS::AtLeastOnce,
89 2 => QoS::ExactlyOnce,
90 _ => {
91 return Err(MqttError::MalformedPacket(format!(
92 "Invalid QoS value in subscription options: {}",
93 self.qos
94 )))
95 }
96 };
97
98 let retain_handling = match self.retain_handling {
100 0 => RetainHandling::SendAtSubscribe,
101 1 => RetainHandling::SendAtSubscribeIfNew,
102 2 => RetainHandling::DoNotSend,
103 _ => {
104 return Err(MqttError::MalformedPacket(format!(
105 "Invalid retain handling value: {}",
106 self.retain_handling
107 )))
108 }
109 };
110
111 Ok(SubscriptionOptions {
112 qos,
113 no_local: self.no_local != 0,
114 retain_as_published: self.retain_as_published != 0,
115 retain_handling,
116 })
117 }
118}
119
120impl Default for SubscriptionOptions {
121 fn default() -> Self {
122 Self {
123 qos: QoS::AtMostOnce,
124 no_local: false,
125 retain_as_published: false,
126 retain_handling: RetainHandling::SendAtSubscribe,
127 }
128 }
129}
130
131impl SubscriptionOptions {
132 #[must_use]
134 pub fn new(qos: QoS) -> Self {
135 Self {
136 qos,
137 ..Default::default()
138 }
139 }
140
141 #[must_use]
143 pub fn with_qos(mut self, qos: QoS) -> Self {
144 self.qos = qos;
145 self
146 }
147
148 #[must_use]
151 pub fn encode(&self) -> u8 {
152 let mut byte = self.qos as u8;
153
154 if self.no_local {
155 byte |= 0x04;
156 }
157
158 if self.retain_as_published {
159 byte |= 0x08;
160 }
161
162 byte |= (self.retain_handling as u8) << 4;
163
164 byte
165 }
166
167 #[must_use]
170 pub fn encode_with_bebytes(&self) -> u8 {
171 let bits = SubscriptionOptionsBits::from_options(self);
172 bits.to_be_bytes()[0]
173 }
174
175 pub fn decode(byte: u8) -> Result<Self> {
182 let qos_val = byte & crate::constants::subscription::QOS_MASK;
183 let qos = match qos_val {
184 0 => QoS::AtMostOnce,
185 1 => QoS::AtLeastOnce,
186 2 => QoS::ExactlyOnce,
187 _ => {
188 return Err(MqttError::MalformedPacket(format!(
189 "Invalid QoS value in subscription options: {qos_val}"
190 )))
191 }
192 };
193
194 let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
195 let retain_as_published =
196 (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
197
198 let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
199 & crate::constants::subscription::QOS_MASK;
200 let retain_handling = match retain_handling_val {
201 0 => RetainHandling::SendAtSubscribe,
202 1 => RetainHandling::SendAtSubscribeIfNew,
203 2 => RetainHandling::DoNotSend,
204 _ => {
205 return Err(MqttError::MalformedPacket(format!(
206 "Invalid retain handling value: {retain_handling_val}"
207 )))
208 }
209 };
210
211 if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
213 return Err(MqttError::MalformedPacket(
214 "Reserved bits in subscription options must be 0".to_string(),
215 ));
216 }
217
218 Ok(Self {
219 qos,
220 no_local,
221 retain_as_published,
222 retain_handling,
223 })
224 }
225
226 pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
233 let (bits, _consumed) =
234 SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
235 MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
236 })?;
237
238 bits.to_options()
239 }
240}
241
242#[derive(Debug, Clone, PartialEq, Eq)]
244pub struct TopicFilter {
245 pub filter: String,
247 pub options: SubscriptionOptions,
249}
250
251impl TopicFilter {
252 #[must_use]
254 pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
255 Self {
256 filter: filter.into(),
257 options: SubscriptionOptions::new(qos),
258 }
259 }
260
261 #[must_use]
263 pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
264 Self {
265 filter: filter.into(),
266 options,
267 }
268 }
269}
270
271#[derive(Debug, Clone)]
273pub struct SubscribePacket {
274 pub packet_id: u16,
276 pub filters: Vec<TopicFilter>,
278 pub properties: Properties,
280 pub protocol_version: u8,
282}
283
284impl SubscribePacket {
285 #[must_use]
287 pub fn new(packet_id: u16) -> Self {
288 Self {
289 packet_id,
290 filters: Vec::new(),
291 properties: Properties::default(),
292 protocol_version: 5,
293 }
294 }
295
296 #[must_use]
298 pub fn new_v311(packet_id: u16) -> Self {
299 Self {
300 packet_id,
301 filters: Vec::new(),
302 properties: Properties::default(),
303 protocol_version: 4,
304 }
305 }
306
307 #[must_use]
309 pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
310 self.filters.push(TopicFilter::new(filter, qos));
311 self
312 }
313
314 #[must_use]
316 pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
317 self.filters.push(filter);
318 self
319 }
320
321 #[must_use]
323 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
324 self.properties.set_subscription_identifier(id);
325 self
326 }
327
328 #[must_use]
330 pub fn with_user_property(mut self, key: String, value: String) -> Self {
331 self.properties.add_user_property(key, value);
332 self
333 }
334}
335
336impl MqttPacket for SubscribePacket {
337 fn packet_type(&self) -> PacketType {
338 PacketType::Subscribe
339 }
340
341 fn flags(&self) -> u8 {
342 0x02 }
344
345 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
346 buf.put_u16(self.packet_id);
347
348 if self.protocol_version == 5 {
349 self.properties.encode(buf)?;
350 }
351
352 if self.filters.is_empty() {
353 return Err(MqttError::MalformedPacket(
354 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
355 ));
356 }
357
358 for filter in &self.filters {
359 encode_string(buf, &filter.filter)?;
360 if self.protocol_version == 5 {
361 buf.put_u8(filter.options.encode());
362 } else {
363 buf.put_u8(filter.options.qos as u8);
364 }
365 }
366
367 Ok(())
368 }
369
370 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
371 Self::decode_body_with_version(buf, fixed_header, 5)
372 }
373}
374
375impl SubscribePacket {
376 pub fn decode_body_with_version<B: Buf>(
382 buf: &mut B,
383 fixed_header: &FixedHeader,
384 protocol_version: u8,
385 ) -> Result<Self> {
386 ProtocolVersion::try_from(protocol_version)
387 .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
388
389 if fixed_header.flags != 0x02 {
390 return Err(MqttError::MalformedPacket(format!(
391 "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
392 fixed_header.flags
393 )));
394 }
395
396 if buf.remaining() < 2 {
397 return Err(MqttError::MalformedPacket(
398 "SUBSCRIBE missing packet identifier".to_string(),
399 ));
400 }
401 let packet_id = buf.get_u16();
402
403 let properties = if protocol_version == 5 {
404 Properties::decode(buf)?
405 } else {
406 Properties::default()
407 };
408
409 let mut filters = Vec::new();
410
411 if !buf.has_remaining() {
412 return Err(MqttError::MalformedPacket(
413 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
414 ));
415 }
416
417 while buf.has_remaining() {
418 let filter_str = decode_string(buf)?;
419
420 if !buf.has_remaining() {
421 return Err(MqttError::MalformedPacket(
422 "Missing subscription options for topic filter".to_string(),
423 ));
424 }
425
426 let options_byte = buf.get_u8();
427 let options = if protocol_version == 5 {
428 SubscriptionOptions::decode(options_byte)?
429 } else {
430 SubscriptionOptions {
431 qos: QoS::from(options_byte & 0x03),
432 ..Default::default()
433 }
434 };
435
436 filters.push(TopicFilter {
437 filter: filter_str,
438 options,
439 });
440 }
441
442 Ok(Self {
443 packet_id,
444 filters,
445 properties,
446 protocol_version,
447 })
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    /// Cross-checks the hand-written bit manipulation against the
    /// `BeBytes`-derived layout. The enclosing module is already
    /// `#[cfg(test)]`, so this module needs no second gate.
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            let test_cases = vec![
                SubscriptionOptions::default(),
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            let decoded_options = decoded_bits.to_options().unwrap();
            assert_eq!(decoded_options, options);
        }

        #[test]
        fn test_reserved_bits_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

            bits.reserved_bits = 1;
            assert!(bits.to_options().is_err());

            bits.reserved_bits = 2;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_qos_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.qos = 3;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_retain_handling_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.retain_handling = 3;
            assert!(bits.to_options().is_err());
        }

        proptest! {
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let qos_enum = match qos {
                    0 => QoS::AtMostOnce,
                    1 => QoS::AtLeastOnce,
                    2 => QoS::ExactlyOnce,
                    _ => unreachable!(),
                };

                let retain_handling_enum = match retain_handling {
                    0 => RetainHandling::SendAtSubscribe,
                    1 => RetainHandling::SendAtSubscribeIfNew,
                    2 => RetainHandling::DoNotSend,
                    _ => unreachable!(),
                };

                let options = SubscriptionOptions {
                    qos: qos_enum,
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_enum,
                };

                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        assert_eq!(encoded, 0x1D);

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    #[test]
    fn test_subscription_options_decode_rejects_invalid_bytes() {
        // QoS field = 3.
        assert!(SubscriptionOptions::decode(0x03).is_err());
        // Retain Handling field = 3.
        assert!(SubscriptionOptions::decode(0x30).is_err());
        // Reserved bits 6 and 7.
        assert!(SubscriptionOptions::decode(0x40).is_err());
        assert!(SubscriptionOptions::decode(0x80).is_err());
    }

    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    #[test]
    fn test_subscribe_v311_encode_decode() {
        let packet = SubscribePacket::new_v311(321)
            .add_filter("a/b", QoS::ExactlyOnce)
            .add_filter("c/#", QoS::AtMostOnce);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);

        let decoded =
            SubscribePacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();
        assert_eq!(decoded.packet_id, 321);
        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.filters, packet.filters);
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        // SUBSCRIBE requires flags 0x02; 0x00 must be rejected.
        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}