1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::protocol::v5::properties::Properties;
5use crate::QoS;
6use bebytes::BeBytes;
7use bytes::{Buf, BufMut};
8
/// Per-subscription options that accompany each topic filter in a SUBSCRIBE packet.
///
/// On the wire these options occupy a single byte; see
/// [`SubscriptionOptions::encode`] and [`SubscriptionOptions::decode`] for the
/// exact bit positions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubscriptionOptions {
    /// Requested quality of service; occupies the lowest bits of the options byte.
    pub qos: QoS,
    /// No Local flag; bit 2 (`0x04`) of the options byte.
    pub no_local: bool,
    /// Retain As Published flag; bit 3 (`0x08`) of the options byte.
    pub retain_as_published: bool,
    /// Retain handling mode; stored shifted left by four bits in the options byte.
    pub retain_handling: RetainHandling,
}
21
/// Retain handling mode of a subscription.
///
/// The explicit discriminants are the numeric values used on the wire: the
/// encoders in this file cast the enum with `as u8`, and the decoders map
/// `0`, `1` and `2` back to the variants below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RetainHandling {
    /// Wire value `0`. Per the MQTT v5 Retain Handling option: retained messages
    /// are sent at the time of the subscribe.
    SendAtSubscribe = 0,
    /// Wire value `1`. Per the MQTT v5 Retain Handling option: retained messages
    /// are sent only if the subscription did not already exist.
    SendAtSubscribeIfNew = 1,
    /// Wire value `2`. Per the MQTT v5 Retain Handling option: retained messages
    /// are not sent at the time of the subscribe.
    DoNotSend = 2,
}
33
/// Raw bit-field view of the subscription options byte, (de)serialized by `bebytes`.
///
/// The field widths add up to eight bits, so the struct serializes to one byte.
/// Field order matters: the tests in this file assert that the packed byte is
/// identical to [`SubscriptionOptions::encode`], which implies the fields are
/// laid out from the most significant bit downwards:
/// `[reserved:2][retain_handling:2][retain_as_published:1][no_local:1][qos:2]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct SubscriptionOptionsBits {
    #[bits(2)]
    /// Reserved bits; `to_options` rejects any non-zero value.
    pub reserved_bits: u8,
    #[bits(2)]
    /// Raw retain handling value; only `0`, `1` and `2` are accepted by `to_options`.
    pub retain_handling: u8,
    #[bits(1)]
    /// Retain As Published flag (`0` or `1`).
    pub retain_as_published: u8,
    #[bits(1)]
    /// No Local flag (`0` or `1`).
    pub no_local: u8,
    #[bits(2)]
    /// Raw QoS value; only `0`, `1` and `2` are accepted by `to_options`.
    pub qos: u8,
}
55
56impl SubscriptionOptionsBits {
57 #[must_use]
60 pub fn from_options(options: &SubscriptionOptions) -> Self {
61 Self {
62 reserved_bits: 0,
63 retain_handling: options.retain_handling as u8,
64 retain_as_published: u8::from(options.retain_as_published),
65 no_local: u8::from(options.no_local),
66 qos: options.qos as u8,
67 }
68 }
69
70 pub fn to_options(&self) -> Result<SubscriptionOptions> {
77 if self.reserved_bits != 0 {
79 return Err(MqttError::MalformedPacket(
80 "Reserved bits in subscription options must be 0".to_string(),
81 ));
82 }
83
84 let qos = match self.qos {
86 0 => QoS::AtMostOnce,
87 1 => QoS::AtLeastOnce,
88 2 => QoS::ExactlyOnce,
89 _ => {
90 return Err(MqttError::MalformedPacket(format!(
91 "Invalid QoS value in subscription options: {}",
92 self.qos
93 )))
94 }
95 };
96
97 let retain_handling = match self.retain_handling {
99 0 => RetainHandling::SendAtSubscribe,
100 1 => RetainHandling::SendAtSubscribeIfNew,
101 2 => RetainHandling::DoNotSend,
102 _ => {
103 return Err(MqttError::MalformedPacket(format!(
104 "Invalid retain handling value: {}",
105 self.retain_handling
106 )))
107 }
108 };
109
110 Ok(SubscriptionOptions {
111 qos,
112 no_local: self.no_local != 0,
113 retain_as_published: self.retain_as_published != 0,
114 retain_handling,
115 })
116 }
117}
118
119impl Default for SubscriptionOptions {
120 fn default() -> Self {
121 Self {
122 qos: QoS::AtMostOnce,
123 no_local: false,
124 retain_as_published: false,
125 retain_handling: RetainHandling::SendAtSubscribe,
126 }
127 }
128}
129
130impl SubscriptionOptions {
131 #[must_use]
133 pub fn new(qos: QoS) -> Self {
134 Self {
135 qos,
136 ..Default::default()
137 }
138 }
139
140 #[must_use]
142 pub fn with_qos(mut self, qos: QoS) -> Self {
143 self.qos = qos;
144 self
145 }
146
147 #[must_use]
150 pub fn encode(&self) -> u8 {
151 let mut byte = self.qos as u8;
152
153 if self.no_local {
154 byte |= 0x04;
155 }
156
157 if self.retain_as_published {
158 byte |= 0x08;
159 }
160
161 byte |= (self.retain_handling as u8) << 4;
162
163 byte
164 }
165
166 #[must_use]
169 pub fn encode_with_bebytes(&self) -> u8 {
170 let bits = SubscriptionOptionsBits::from_options(self);
171 bits.to_be_bytes()[0]
172 }
173
174 pub fn decode(byte: u8) -> Result<Self> {
181 let qos_val = byte & crate::constants::subscription::QOS_MASK;
182 let qos = match qos_val {
183 0 => QoS::AtMostOnce,
184 1 => QoS::AtLeastOnce,
185 2 => QoS::ExactlyOnce,
186 _ => {
187 return Err(MqttError::MalformedPacket(format!(
188 "Invalid QoS value in subscription options: {qos_val}"
189 )))
190 }
191 };
192
193 let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
194 let retain_as_published =
195 (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
196
197 let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
198 & crate::constants::subscription::QOS_MASK;
199 let retain_handling = match retain_handling_val {
200 0 => RetainHandling::SendAtSubscribe,
201 1 => RetainHandling::SendAtSubscribeIfNew,
202 2 => RetainHandling::DoNotSend,
203 _ => {
204 return Err(MqttError::MalformedPacket(format!(
205 "Invalid retain handling value: {retain_handling_val}"
206 )))
207 }
208 };
209
210 if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
212 return Err(MqttError::MalformedPacket(
213 "Reserved bits in subscription options must be 0".to_string(),
214 ));
215 }
216
217 Ok(Self {
218 qos,
219 no_local,
220 retain_as_published,
221 retain_handling,
222 })
223 }
224
225 pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
232 let (bits, _consumed) =
233 SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
234 MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
235 })?;
236
237 bits.to_options()
238 }
239}
240
/// A topic filter string paired with the options requested for that subscription.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopicFilter {
    /// The topic filter text, written to the wire with `encode_string`.
    pub filter: String,
    /// Options encoded as the single byte that follows the filter string.
    pub options: SubscriptionOptions,
}
249
250impl TopicFilter {
251 #[must_use]
253 pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
254 Self {
255 filter: filter.into(),
256 options: SubscriptionOptions::new(qos),
257 }
258 }
259
260 #[must_use]
262 pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
263 Self {
264 filter: filter.into(),
265 options,
266 }
267 }
268}
269
/// An MQTT v5 SUBSCRIBE packet: packet identifier, properties and topic filters.
#[derive(Debug, Clone)]
pub struct SubscribePacket {
    /// Packet identifier, written as the first two bytes of the packet body.
    pub packet_id: u16,
    /// Topic filters to subscribe to; encoding fails if this list is empty.
    pub filters: Vec<TopicFilter>,
    /// Properties, encoded between the packet identifier and the filters.
    pub properties: Properties,
}
280
281impl SubscribePacket {
282 #[must_use]
284 pub fn new(packet_id: u16) -> Self {
285 Self {
286 packet_id,
287 filters: Vec::new(),
288 properties: Properties::default(),
289 }
290 }
291
292 #[must_use]
294 pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
295 self.filters.push(TopicFilter::new(filter, qos));
296 self
297 }
298
299 #[must_use]
301 pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
302 self.filters.push(filter);
303 self
304 }
305
306 #[must_use]
308 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
309 self.properties.set_subscription_identifier(id);
310 self
311 }
312
313 #[must_use]
315 pub fn with_user_property(mut self, key: String, value: String) -> Self {
316 self.properties.add_user_property(key, value);
317 self
318 }
319}
320
321impl MqttPacket for SubscribePacket {
322 fn packet_type(&self) -> PacketType {
323 PacketType::Subscribe
324 }
325
326 fn flags(&self) -> u8 {
327 0x02 }
329
330 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
331 buf.put_u16(self.packet_id);
333
334 self.properties.encode(buf)?;
336
337 if self.filters.is_empty() {
339 return Err(MqttError::MalformedPacket(
340 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
341 ));
342 }
343
344 for filter in &self.filters {
345 encode_string(buf, &filter.filter)?;
346 buf.put_u8(filter.options.encode());
347 }
348
349 Ok(())
350 }
351
352 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
353 if fixed_header.flags != 0x02 {
355 return Err(MqttError::MalformedPacket(format!(
356 "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
357 fixed_header.flags
358 )));
359 }
360
361 if buf.remaining() < 2 {
363 return Err(MqttError::MalformedPacket(
364 "SUBSCRIBE missing packet identifier".to_string(),
365 ));
366 }
367 let packet_id = buf.get_u16();
368
369 let properties = Properties::decode(buf)?;
371
372 let mut filters = Vec::new();
374
375 if !buf.has_remaining() {
376 return Err(MqttError::MalformedPacket(
377 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
378 ));
379 }
380
381 while buf.has_remaining() {
382 let filter_str = decode_string(buf)?;
383
384 if !buf.has_remaining() {
385 return Err(MqttError::MalformedPacket(
386 "Missing subscription options for topic filter".to_string(),
387 ));
388 }
389
390 let options_byte = buf.get_u8();
391 let options = SubscriptionOptions::decode(options_byte)?;
392
393 filters.push(TopicFilter {
394 filter: filter_str,
395 options,
396 });
397 }
398
399 Ok(Self {
400 packet_id,
401 filters,
402 properties,
403 })
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    // Checks that the hand-written bit manipulation and the `bebytes`-derived
    // bit-field struct agree. The enclosing module is already `#[cfg(test)]`,
    // so no attribute is needed here.
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            // A fixed-size array is enough: the cases are only iterated.
            let test_cases = [
                SubscriptionOptions::default(),
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                // Both decoders must round-trip to the original options.
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            let decoded_options = decoded_bits.to_options().unwrap();
            assert_eq!(decoded_options, options);
        }

        #[test]
        fn test_reserved_bits_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

            // Either reserved bit being set must be rejected.
            bits.reserved_bits = 1;
            assert!(bits.to_options().is_err());

            bits.reserved_bits = 2;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_qos_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            // 3 fits in the two-bit field but is not a valid QoS.
            bits.qos = 3;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_retain_handling_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            // 3 fits in the two-bit field but is not a valid retain handling mode.
            bits.retain_handling = 3;
            assert!(bits.to_options().is_err());
        }

        proptest! {
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let qos_enum = match qos {
                    0 => QoS::AtMostOnce,
                    1 => QoS::AtLeastOnce,
                    2 => QoS::ExactlyOnce,
                    _ => unreachable!(),
                };

                let retain_handling_enum = match retain_handling {
                    0 => RetainHandling::SendAtSubscribe,
                    1 => RetainHandling::SendAtSubscribeIfNew,
                    2 => RetainHandling::DoNotSend,
                    _ => unreachable!(),
                };

                let options = SubscriptionOptions {
                    qos: qos_enum,
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_enum,
                };

                // Both encoders must produce the same byte.
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                // Both decoders must recover the original options.
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                // The high-level view must match the raw inputs field by field.
                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        // 0x01 (QoS 1) | 0x04 (no local) | 0x08 (retain as published) | 0x10 (retain handling 1).
        assert_eq!(encoded, 0x1D);

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    #[test]
    fn test_subscription_options_decode_rejects_invalid_bytes() {
        // Each byte violates exactly one rule; both decoders must reject it.
        let invalid_bytes = [
            0x03, // QoS field set to 3
            0x30, // retain handling field set to 3
            0x40, // lower reserved bit set
            0x80, // upper reserved bit set
        ];

        for byte in invalid_bytes {
            assert!(
                SubscriptionOptions::decode(byte).is_err(),
                "manual decode should reject 0x{byte:02X}"
            );
            assert!(
                SubscriptionOptions::decode_with_bebytes(byte).is_err(),
                "bebytes decode should reject 0x{byte:02X}"
            );
        }
    }

    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        // Flags 0x00 instead of the required 0x02.
        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}