mqtt5_protocol/packet/
subscribe.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::prelude::{format, String, ToString, Vec};
5use crate::protocol::v5::properties::Properties;
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bebytes::BeBytes;
9use bytes::{Buf, BufMut};
10
11/// Subscription options (v5.0)
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub struct SubscriptionOptions {
14    /// Maximum `QoS` level the client will accept
15    pub qos: QoS,
16    /// No Local option - if true, Application Messages MUST NOT be forwarded to this connection
17    pub no_local: bool,
18    /// Retain As Published - if true, keep the RETAIN flag as published
19    pub retain_as_published: bool,
20    /// Retain Handling option
21    pub retain_handling: RetainHandling,
22}
23
/// Retain handling choices for a subscription.
///
/// The explicit discriminants are the numeric values used when the option is
/// packed into, or unpacked from, a subscription options byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetainHandling {
    /// Retained messages are sent at subscribe time.
    SendAtSubscribe = 0,
    /// Retained messages are sent at subscribe time only if the subscription
    /// does not already exist.
    SendAtSubscribeIfNew = 1,
    /// Retained messages are not sent at subscribe time.
    DoNotSend = 2,
}
35
36/// Subscription options using bebytes for bit field operations
37/// This demonstrates the hybrid approach for complex packet variable headers
38/// Bit fields are ordered from MSB to LSB (bits 7-0)
39#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
40pub struct SubscriptionOptionsBits {
41    /// Reserved bits (bits 7-6) - must be 0
42    #[bits(2)]
43    pub reserved_bits: u8,
44    /// Retain Handling (bits 5-4)
45    #[bits(2)]
46    pub retain_handling: u8,
47    /// Retain As Published flag (bit 3)
48    #[bits(1)]
49    pub retain_as_published: u8,
50    /// No Local flag (bit 2)
51    #[bits(1)]
52    pub no_local: u8,
53    /// `QoS` level (bits 1-0)
54    #[bits(2)]
55    pub qos: u8,
56}
57
58impl SubscriptionOptionsBits {
59    /// Creates subscription options bits from high-level `SubscriptionOptions`
60    /// Bebytes handles bit field layout, Rust handles type safety and validation
61    #[must_use]
62    pub fn from_options(options: &SubscriptionOptions) -> Self {
63        Self {
64            reserved_bits: 0,
65            retain_handling: options.retain_handling as u8,
66            retain_as_published: u8::from(options.retain_as_published),
67            no_local: u8::from(options.no_local),
68            qos: options.qos as u8,
69        }
70    }
71
72    /// Converts bebytes bit fields back to high-level `SubscriptionOptions`
73    /// Bebytes provides the bits, Rust handles validation and type conversion
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if reserved bits are set, or if `QoS` or retain handling values are invalid
78    pub fn to_options(&self) -> Result<SubscriptionOptions> {
79        // Validate reserved bits are zero
80        if self.reserved_bits != 0 {
81            return Err(MqttError::MalformedPacket(
82                "Reserved bits in subscription options must be 0".to_string(),
83            ));
84        }
85
86        // Validate and convert QoS
87        let qos = match self.qos {
88            0 => QoS::AtMostOnce,
89            1 => QoS::AtLeastOnce,
90            2 => QoS::ExactlyOnce,
91            _ => {
92                return Err(MqttError::MalformedPacket(format!(
93                    "Invalid QoS value in subscription options: {}",
94                    self.qos
95                )))
96            }
97        };
98
99        // Validate and convert retain handling
100        let retain_handling = match self.retain_handling {
101            0 => RetainHandling::SendAtSubscribe,
102            1 => RetainHandling::SendAtSubscribeIfNew,
103            2 => RetainHandling::DoNotSend,
104            _ => {
105                return Err(MqttError::MalformedPacket(format!(
106                    "Invalid retain handling value: {}",
107                    self.retain_handling
108                )))
109            }
110        };
111
112        Ok(SubscriptionOptions {
113            qos,
114            no_local: self.no_local != 0,
115            retain_as_published: self.retain_as_published != 0,
116            retain_handling,
117        })
118    }
119}
120
121impl Default for SubscriptionOptions {
122    fn default() -> Self {
123        Self {
124            qos: QoS::AtMostOnce,
125            no_local: false,
126            retain_as_published: false,
127            retain_handling: RetainHandling::SendAtSubscribe,
128        }
129    }
130}
131
132impl SubscriptionOptions {
133    /// Creates subscription options with the specified `QoS`
134    #[must_use]
135    pub fn new(qos: QoS) -> Self {
136        Self {
137            qos,
138            ..Default::default()
139        }
140    }
141
142    /// Sets the `QoS` level
143    #[must_use]
144    pub fn with_qos(mut self, qos: QoS) -> Self {
145        self.qos = qos;
146        self
147    }
148
149    /// Encodes subscription options as a byte (v5.0)
150    /// Original manual implementation for comparison
151    #[must_use]
152    pub fn encode(&self) -> u8 {
153        let mut byte = self.qos as u8;
154
155        if self.no_local {
156            byte |= 0x04;
157        }
158
159        if self.retain_as_published {
160            byte |= 0x08;
161        }
162
163        byte |= (self.retain_handling as u8) << 4;
164
165        byte
166    }
167
168    /// Encodes subscription options using bebytes (hybrid approach)
169    /// Bebytes handles bit field operations, Rust handles type safety
170    #[must_use]
171    pub fn encode_with_bebytes(&self) -> u8 {
172        let bits = SubscriptionOptionsBits::from_options(self);
173        bits.to_be_bytes()[0]
174    }
175
176    /// Decodes subscription options from a byte (v5.0)
177    /// Original manual implementation for comparison
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if the `QoS` value is invalid
182    pub fn decode(byte: u8) -> Result<Self> {
183        let qos_val = byte & crate::constants::subscription::QOS_MASK;
184        let qos = match qos_val {
185            0 => QoS::AtMostOnce,
186            1 => QoS::AtLeastOnce,
187            2 => QoS::ExactlyOnce,
188            _ => {
189                return Err(MqttError::MalformedPacket(format!(
190                    "Invalid QoS value in subscription options: {qos_val}"
191                )))
192            }
193        };
194
195        let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
196        let retain_as_published =
197            (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
198
199        let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
200            & crate::constants::subscription::QOS_MASK;
201        let retain_handling = match retain_handling_val {
202            0 => RetainHandling::SendAtSubscribe,
203            1 => RetainHandling::SendAtSubscribeIfNew,
204            2 => RetainHandling::DoNotSend,
205            _ => {
206                return Err(MqttError::MalformedPacket(format!(
207                    "Invalid retain handling value: {retain_handling_val}"
208                )))
209            }
210        };
211
212        // Check reserved bits
213        if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
214            return Err(MqttError::MalformedPacket(
215                "Reserved bits in subscription options must be 0".to_string(),
216            ));
217        }
218
219        Ok(Self {
220            qos,
221            no_local,
222            retain_as_published,
223            retain_handling,
224        })
225    }
226
227    /// Decodes subscription options using bebytes (hybrid approach)\
228    /// Bebytes handles bit field extraction, Rust handles validation and type conversion
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the `QoS` value or retain handling is invalid, or reserved bits are set
233    pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
234        let (bits, _consumed) =
235            SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
236                MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
237            })?;
238
239        bits.to_options()
240    }
241}
242
243/// Topic filter with subscription options
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct TopicFilter {
246    /// Topic filter string (may contain wildcards)
247    pub filter: String,
248    /// Subscription options
249    pub options: SubscriptionOptions,
250}
251
252impl TopicFilter {
253    /// Creates a new topic filter with the specified `QoS`
254    #[must_use]
255    pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
256        Self {
257            filter: filter.into(),
258            options: SubscriptionOptions::new(qos),
259        }
260    }
261
262    /// Creates a new topic filter with custom options
263    #[must_use]
264    pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
265        Self {
266            filter: filter.into(),
267            options,
268        }
269    }
270}
271
272/// MQTT SUBSCRIBE packet
273#[derive(Debug, Clone)]
274pub struct SubscribePacket {
275    /// Packet identifier
276    pub packet_id: u16,
277    /// Topic filters to subscribe to
278    pub filters: Vec<TopicFilter>,
279    /// SUBSCRIBE properties (v5.0 only)
280    pub properties: Properties,
281    /// Protocol version (4 = v3.1.1, 5 = v5.0)
282    pub protocol_version: u8,
283}
284
285impl SubscribePacket {
286    /// Creates a new SUBSCRIBE packet (v5.0)
287    #[must_use]
288    pub fn new(packet_id: u16) -> Self {
289        Self {
290            packet_id,
291            filters: Vec::new(),
292            properties: Properties::default(),
293            protocol_version: 5,
294        }
295    }
296
297    /// Creates a new SUBSCRIBE packet for v3.1.1
298    #[must_use]
299    pub fn new_v311(packet_id: u16) -> Self {
300        Self {
301            packet_id,
302            filters: Vec::new(),
303            properties: Properties::default(),
304            protocol_version: 4,
305        }
306    }
307
308    /// Adds a topic filter
309    #[must_use]
310    pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
311        self.filters.push(TopicFilter::new(filter, qos));
312        self
313    }
314
315    /// Adds a topic filter with options
316    #[must_use]
317    pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
318        self.filters.push(filter);
319        self
320    }
321
322    /// Sets the subscription identifier
323    #[must_use]
324    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
325        self.properties.set_subscription_identifier(id);
326        self
327    }
328
329    /// Adds a user property
330    #[must_use]
331    pub fn with_user_property(mut self, key: String, value: String) -> Self {
332        self.properties.add_user_property(key, value);
333        self
334    }
335}
336
337impl MqttPacket for SubscribePacket {
338    fn packet_type(&self) -> PacketType {
339        PacketType::Subscribe
340    }
341
342    fn flags(&self) -> u8 {
343        0x02 // SUBSCRIBE must have flags = 0x02
344    }
345
346    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
347        buf.put_u16(self.packet_id);
348
349        if self.protocol_version == 5 {
350            self.properties.encode(buf)?;
351        }
352
353        if self.filters.is_empty() {
354            return Err(MqttError::MalformedPacket(
355                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
356            ));
357        }
358
359        for filter in &self.filters {
360            encode_string(buf, &filter.filter)?;
361            if self.protocol_version == 5 {
362                buf.put_u8(filter.options.encode());
363            } else {
364                buf.put_u8(filter.options.qos as u8);
365            }
366        }
367
368        Ok(())
369    }
370
371    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
372        Self::decode_body_with_version(buf, fixed_header, 5)
373    }
374}
375
376impl SubscribePacket {
377    /// Decodes the packet body with a specific protocol version
378    ///
379    /// # Errors
380    ///
381    /// Returns an error if decoding fails
382    pub fn decode_body_with_version<B: Buf>(
383        buf: &mut B,
384        fixed_header: &FixedHeader,
385        protocol_version: u8,
386    ) -> Result<Self> {
387        ProtocolVersion::try_from(protocol_version)
388            .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
389
390        if fixed_header.flags != 0x02 {
391            return Err(MqttError::MalformedPacket(format!(
392                "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
393                fixed_header.flags
394            )));
395        }
396
397        if buf.remaining() < 2 {
398            return Err(MqttError::MalformedPacket(
399                "SUBSCRIBE missing packet identifier".to_string(),
400            ));
401        }
402        let packet_id = buf.get_u16();
403
404        let properties = if protocol_version == 5 {
405            Properties::decode(buf)?
406        } else {
407            Properties::default()
408        };
409
410        let mut filters = Vec::new();
411
412        if !buf.has_remaining() {
413            return Err(MqttError::MalformedPacket(
414                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
415            ));
416        }
417
418        while buf.has_remaining() {
419            let filter_str = decode_string(buf)?;
420
421            if !buf.has_remaining() {
422                return Err(MqttError::MalformedPacket(
423                    "Missing subscription options for topic filter".to_string(),
424                ));
425            }
426
427            let options_byte = buf.get_u8();
428            let options = if protocol_version == 5 {
429                SubscriptionOptions::decode(options_byte)?
430            } else {
431                SubscriptionOptions {
432                    qos: QoS::from(options_byte & 0x03),
433                    ..Default::default()
434                }
435            };
436
437            filters.push(TopicFilter {
438                filter: filter_str,
439                options,
440            });
441        }
442
443        Ok(Self {
444            packet_id,
445            filters,
446            properties,
447            protocol_version,
448        })
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    /// Tests comparing the manual bit-manipulation codec with the
    /// bebytes-derived codec for subscription options.
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        /// Bit-field view of the default subscription options.
        fn default_bits() -> SubscriptionOptionsBits {
            SubscriptionOptionsBits::from_options(&SubscriptionOptions::default())
        }

        /// Maps a raw value in `0..=2` to its `QoS` variant.
        fn qos_from_raw(raw: u8) -> QoS {
            match raw {
                0 => QoS::AtMostOnce,
                1 => QoS::AtLeastOnce,
                2 => QoS::ExactlyOnce,
                _ => unreachable!(),
            }
        }

        /// Maps a raw value in `0..=2` to its `RetainHandling` variant.
        fn retain_handling_from_raw(raw: u8) -> RetainHandling {
            match raw {
                0 => RetainHandling::SendAtSubscribe,
                1 => RetainHandling::SendAtSubscribeIfNew,
                2 => RetainHandling::DoNotSend,
                _ => unreachable!(),
            }
        }

        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            // Both codecs must agree on encoding and on decoding.
            let test_cases = [
                SubscriptionOptions::default(),
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            assert_eq!(decoded_bits.to_options().unwrap(), options);
        }

        #[test]
        fn test_reserved_bits_validation() {
            // Any non-zero reserved bit pattern must be rejected.
            for pattern in [1u8, 2] {
                let mut bits = default_bits();
                bits.reserved_bits = pattern;
                assert!(bits.to_options().is_err());
            }
        }

        #[test]
        fn test_invalid_qos_validation() {
            let bits = SubscriptionOptionsBits {
                qos: 3, // Invalid QoS
                ..default_bits()
            };
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_retain_handling_validation() {
            let bits = SubscriptionOptionsBits {
                retain_handling: 3, // Invalid retain handling
                ..default_bits()
            };
            assert!(bits.to_options().is_err());
        }

        proptest! {
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let options = SubscriptionOptions {
                    qos: qos_from_raw(qos),
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_from_raw(retain_handling),
                };

                // Encoders must agree.
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                // Decoders must agree and round-trip the input.
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) =
                    SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                // The decoded bits must convert to matching typed options.
                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        // QoS 1 + No Local + RAP + RH 1
        let encoded = options.encode();
        assert_eq!(encoded, 0x1D);

        assert_eq!(SubscriptionOptions::decode(encoded).unwrap(), options);
    }

    #[test]
    fn test_subscribe_basic() {
        let expected = [
            ("temperature/+", QoS::AtLeastOnce),
            ("humidity/#", QoS::ExactlyOnce),
        ];

        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        for (actual, (filter, qos)) in packet.filters.iter().zip(expected.iter()) {
            assert_eq!(actual.filter, *filter);
            assert_eq!(actual.options.qos, *qos);
        }
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };
        let topic = TopicFilter::with_options("test/topic", options);

        let packet = SubscribePacket::new(456).add_filter_with_options(topic);

        let stored = &packet.filters[0].options;
        assert!(stored.no_local);
        assert_eq!(stored.retain_handling, RetainHandling::DoNotSend);
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let expected = [
            ("sensor/temp", QoS::AtMostOnce),
            ("sensor/humidity", QoS::AtLeastOnce),
        ];

        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        for (actual, (filter, qos)) in decoded.filters.iter().zip(expected.iter()) {
            assert_eq!(actual.filter, *filter);
            assert_eq!(actual.options.qos, *qos);
        }

        assert!(decoded
            .properties
            .get(PropertyId::SubscriptionIdentifier)
            .is_some());
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        // Flags 0x00 are wrong for SUBSCRIBE.
        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        assert!(SubscribePacket::decode_body(&mut buf, &fixed_header).is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let mut buf = BytesMut::new();
        let result = SubscribePacket::new(123).encode(&mut buf);
        assert!(result.is_err());
    }
}
717}