mqtt5_protocol/packet/
subscribe.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::protocol::v5::properties::Properties;
5use crate::types::ProtocolVersion;
6use crate::QoS;
7use bebytes::BeBytes;
8use bytes::{Buf, BufMut};
9
10/// Subscription options (v5.0)
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub struct SubscriptionOptions {
13    /// Maximum `QoS` level the client will accept
14    pub qos: QoS,
15    /// No Local option - if true, Application Messages MUST NOT be forwarded to this connection
16    pub no_local: bool,
17    /// Retain As Published - if true, keep the RETAIN flag as published
18    pub retain_as_published: bool,
19    /// Retain Handling option
20    pub retain_handling: RetainHandling,
21}
22
/// Retain Handling subscription option (MQTT v5.0).
///
/// The discriminant is the raw two-bit value written into the subscription
/// options byte; the value 3 has no variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum RetainHandling {
    /// `0`: send retained messages whenever the subscription is made.
    SendAtSubscribe = 0,
    /// `1`: send retained messages only if the subscription did not already exist.
    SendAtSubscribeIfNew = 1,
    /// `2`: never send retained messages at subscribe time.
    DoNotSend = 2,
}
34
35/// Subscription options using bebytes for bit field operations
36/// This demonstrates the hybrid approach for complex packet variable headers
37/// Bit fields are ordered from MSB to LSB (bits 7-0)
38#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
39pub struct SubscriptionOptionsBits {
40    /// Reserved bits (bits 7-6) - must be 0
41    #[bits(2)]
42    pub reserved_bits: u8,
43    /// Retain Handling (bits 5-4)
44    #[bits(2)]
45    pub retain_handling: u8,
46    /// Retain As Published flag (bit 3)
47    #[bits(1)]
48    pub retain_as_published: u8,
49    /// No Local flag (bit 2)
50    #[bits(1)]
51    pub no_local: u8,
52    /// `QoS` level (bits 1-0)
53    #[bits(2)]
54    pub qos: u8,
55}
56
57impl SubscriptionOptionsBits {
58    /// Creates subscription options bits from high-level `SubscriptionOptions`
59    /// Bebytes handles bit field layout, Rust handles type safety and validation
60    #[must_use]
61    pub fn from_options(options: &SubscriptionOptions) -> Self {
62        Self {
63            reserved_bits: 0,
64            retain_handling: options.retain_handling as u8,
65            retain_as_published: u8::from(options.retain_as_published),
66            no_local: u8::from(options.no_local),
67            qos: options.qos as u8,
68        }
69    }
70
71    /// Converts bebytes bit fields back to high-level `SubscriptionOptions`
72    /// Bebytes provides the bits, Rust handles validation and type conversion
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if reserved bits are set, or if `QoS` or retain handling values are invalid
77    pub fn to_options(&self) -> Result<SubscriptionOptions> {
78        // Validate reserved bits are zero
79        if self.reserved_bits != 0 {
80            return Err(MqttError::MalformedPacket(
81                "Reserved bits in subscription options must be 0".to_string(),
82            ));
83        }
84
85        // Validate and convert QoS
86        let qos = match self.qos {
87            0 => QoS::AtMostOnce,
88            1 => QoS::AtLeastOnce,
89            2 => QoS::ExactlyOnce,
90            _ => {
91                return Err(MqttError::MalformedPacket(format!(
92                    "Invalid QoS value in subscription options: {}",
93                    self.qos
94                )))
95            }
96        };
97
98        // Validate and convert retain handling
99        let retain_handling = match self.retain_handling {
100            0 => RetainHandling::SendAtSubscribe,
101            1 => RetainHandling::SendAtSubscribeIfNew,
102            2 => RetainHandling::DoNotSend,
103            _ => {
104                return Err(MqttError::MalformedPacket(format!(
105                    "Invalid retain handling value: {}",
106                    self.retain_handling
107                )))
108            }
109        };
110
111        Ok(SubscriptionOptions {
112            qos,
113            no_local: self.no_local != 0,
114            retain_as_published: self.retain_as_published != 0,
115            retain_handling,
116        })
117    }
118}
119
120impl Default for SubscriptionOptions {
121    fn default() -> Self {
122        Self {
123            qos: QoS::AtMostOnce,
124            no_local: false,
125            retain_as_published: false,
126            retain_handling: RetainHandling::SendAtSubscribe,
127        }
128    }
129}
130
131impl SubscriptionOptions {
132    /// Creates subscription options with the specified `QoS`
133    #[must_use]
134    pub fn new(qos: QoS) -> Self {
135        Self {
136            qos,
137            ..Default::default()
138        }
139    }
140
141    /// Sets the `QoS` level
142    #[must_use]
143    pub fn with_qos(mut self, qos: QoS) -> Self {
144        self.qos = qos;
145        self
146    }
147
148    /// Encodes subscription options as a byte (v5.0)
149    /// Original manual implementation for comparison
150    #[must_use]
151    pub fn encode(&self) -> u8 {
152        let mut byte = self.qos as u8;
153
154        if self.no_local {
155            byte |= 0x04;
156        }
157
158        if self.retain_as_published {
159            byte |= 0x08;
160        }
161
162        byte |= (self.retain_handling as u8) << 4;
163
164        byte
165    }
166
167    /// Encodes subscription options using bebytes (hybrid approach)
168    /// Bebytes handles bit field operations, Rust handles type safety
169    #[must_use]
170    pub fn encode_with_bebytes(&self) -> u8 {
171        let bits = SubscriptionOptionsBits::from_options(self);
172        bits.to_be_bytes()[0]
173    }
174
175    /// Decodes subscription options from a byte (v5.0)
176    /// Original manual implementation for comparison
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the `QoS` value is invalid
181    pub fn decode(byte: u8) -> Result<Self> {
182        let qos_val = byte & crate::constants::subscription::QOS_MASK;
183        let qos = match qos_val {
184            0 => QoS::AtMostOnce,
185            1 => QoS::AtLeastOnce,
186            2 => QoS::ExactlyOnce,
187            _ => {
188                return Err(MqttError::MalformedPacket(format!(
189                    "Invalid QoS value in subscription options: {qos_val}"
190                )))
191            }
192        };
193
194        let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
195        let retain_as_published =
196            (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
197
198        let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
199            & crate::constants::subscription::QOS_MASK;
200        let retain_handling = match retain_handling_val {
201            0 => RetainHandling::SendAtSubscribe,
202            1 => RetainHandling::SendAtSubscribeIfNew,
203            2 => RetainHandling::DoNotSend,
204            _ => {
205                return Err(MqttError::MalformedPacket(format!(
206                    "Invalid retain handling value: {retain_handling_val}"
207                )))
208            }
209        };
210
211        // Check reserved bits
212        if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
213            return Err(MqttError::MalformedPacket(
214                "Reserved bits in subscription options must be 0".to_string(),
215            ));
216        }
217
218        Ok(Self {
219            qos,
220            no_local,
221            retain_as_published,
222            retain_handling,
223        })
224    }
225
226    /// Decodes subscription options using bebytes (hybrid approach)\
227    /// Bebytes handles bit field extraction, Rust handles validation and type conversion
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if the `QoS` value or retain handling is invalid, or reserved bits are set
232    pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
233        let (bits, _consumed) =
234            SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
235                MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
236            })?;
237
238        bits.to_options()
239    }
240}
241
242/// Topic filter with subscription options
243#[derive(Debug, Clone, PartialEq, Eq)]
244pub struct TopicFilter {
245    /// Topic filter string (may contain wildcards)
246    pub filter: String,
247    /// Subscription options
248    pub options: SubscriptionOptions,
249}
250
251impl TopicFilter {
252    /// Creates a new topic filter with the specified `QoS`
253    #[must_use]
254    pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
255        Self {
256            filter: filter.into(),
257            options: SubscriptionOptions::new(qos),
258        }
259    }
260
261    /// Creates a new topic filter with custom options
262    #[must_use]
263    pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
264        Self {
265            filter: filter.into(),
266            options,
267        }
268    }
269}
270
271/// MQTT SUBSCRIBE packet
272#[derive(Debug, Clone)]
273pub struct SubscribePacket {
274    /// Packet identifier
275    pub packet_id: u16,
276    /// Topic filters to subscribe to
277    pub filters: Vec<TopicFilter>,
278    /// SUBSCRIBE properties (v5.0 only)
279    pub properties: Properties,
280    /// Protocol version (4 = v3.1.1, 5 = v5.0)
281    pub protocol_version: u8,
282}
283
284impl SubscribePacket {
285    /// Creates a new SUBSCRIBE packet (v5.0)
286    #[must_use]
287    pub fn new(packet_id: u16) -> Self {
288        Self {
289            packet_id,
290            filters: Vec::new(),
291            properties: Properties::default(),
292            protocol_version: 5,
293        }
294    }
295
296    /// Creates a new SUBSCRIBE packet for v3.1.1
297    #[must_use]
298    pub fn new_v311(packet_id: u16) -> Self {
299        Self {
300            packet_id,
301            filters: Vec::new(),
302            properties: Properties::default(),
303            protocol_version: 4,
304        }
305    }
306
307    /// Adds a topic filter
308    #[must_use]
309    pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
310        self.filters.push(TopicFilter::new(filter, qos));
311        self
312    }
313
314    /// Adds a topic filter with options
315    #[must_use]
316    pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
317        self.filters.push(filter);
318        self
319    }
320
321    /// Sets the subscription identifier
322    #[must_use]
323    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
324        self.properties.set_subscription_identifier(id);
325        self
326    }
327
328    /// Adds a user property
329    #[must_use]
330    pub fn with_user_property(mut self, key: String, value: String) -> Self {
331        self.properties.add_user_property(key, value);
332        self
333    }
334}
335
336impl MqttPacket for SubscribePacket {
337    fn packet_type(&self) -> PacketType {
338        PacketType::Subscribe
339    }
340
341    fn flags(&self) -> u8 {
342        0x02 // SUBSCRIBE must have flags = 0x02
343    }
344
345    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
346        buf.put_u16(self.packet_id);
347
348        if self.protocol_version == 5 {
349            self.properties.encode(buf)?;
350        }
351
352        if self.filters.is_empty() {
353            return Err(MqttError::MalformedPacket(
354                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
355            ));
356        }
357
358        for filter in &self.filters {
359            encode_string(buf, &filter.filter)?;
360            if self.protocol_version == 5 {
361                buf.put_u8(filter.options.encode());
362            } else {
363                buf.put_u8(filter.options.qos as u8);
364            }
365        }
366
367        Ok(())
368    }
369
370    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
371        Self::decode_body_with_version(buf, fixed_header, 5)
372    }
373}
374
375impl SubscribePacket {
376    /// Decodes the packet body with a specific protocol version
377    ///
378    /// # Errors
379    ///
380    /// Returns an error if decoding fails
381    pub fn decode_body_with_version<B: Buf>(
382        buf: &mut B,
383        fixed_header: &FixedHeader,
384        protocol_version: u8,
385    ) -> Result<Self> {
386        ProtocolVersion::try_from(protocol_version)
387            .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
388
389        if fixed_header.flags != 0x02 {
390            return Err(MqttError::MalformedPacket(format!(
391                "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
392                fixed_header.flags
393            )));
394        }
395
396        if buf.remaining() < 2 {
397            return Err(MqttError::MalformedPacket(
398                "SUBSCRIBE missing packet identifier".to_string(),
399            ));
400        }
401        let packet_id = buf.get_u16();
402
403        let properties = if protocol_version == 5 {
404            Properties::decode(buf)?
405        } else {
406            Properties::default()
407        };
408
409        let mut filters = Vec::new();
410
411        if !buf.has_remaining() {
412            return Err(MqttError::MalformedPacket(
413                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
414            ));
415        }
416
417        while buf.has_remaining() {
418            let filter_str = decode_string(buf)?;
419
420            if !buf.has_remaining() {
421                return Err(MqttError::MalformedPacket(
422                    "Missing subscription options for topic filter".to_string(),
423                ));
424            }
425
426            let options_byte = buf.get_u8();
427            let options = if protocol_version == 5 {
428                SubscriptionOptions::decode(options_byte)?
429            } else {
430                SubscriptionOptions {
431                    qos: QoS::from(options_byte & 0x03),
432                    ..Default::default()
433                }
434            };
435
436            filters.push(TopicFilter {
437                filter: filter_str,
438                options,
439            });
440        }
441
442        Ok(Self {
443            packet_id,
444            filters,
445            properties,
446            protocol_version,
447        })
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    // Cross-checks the bebytes bit-field path against the manual bit-twiddling path.
    // The enclosing module is already `#[cfg(test)]`, so no extra attribute is needed here.
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            // Test that bebytes and manual implementations produce identical results
            let test_cases = [
                SubscriptionOptions::default(),
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                // Also verify decoding produces same results
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        #[test]
        fn test_manual_and_bebytes_decoders_agree_on_every_byte() {
            // Exhaustive over all 256 option bytes: both decoders must accept and
            // reject exactly the same inputs, and accepted bytes must round-trip.
            for byte in 0..=u8::MAX {
                let manual = SubscriptionOptions::decode(byte);
                let bebytes = SubscriptionOptions::decode_with_bebytes(byte);

                assert_eq!(
                    manual.is_ok(),
                    bebytes.is_ok(),
                    "decoders disagree on byte 0x{byte:02X}"
                );

                if let (Ok(manual), Ok(bebytes)) = (manual, bebytes) {
                    assert_eq!(manual, bebytes, "decoded options differ for 0x{byte:02X}");
                    assert_eq!(manual.encode(), byte, "re-encoding changed 0x{byte:02X}");
                }
            }
        }

        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            let decoded_options = decoded_bits.to_options().unwrap();
            assert_eq!(decoded_options, options);
        }

        #[test]
        fn test_reserved_bits_validation() {
            // Test that reserved bits being set causes validation errors
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

            // Set reserved bits
            bits.reserved_bits = 1;
            assert!(bits.to_options().is_err());

            // Set different reserved bit pattern
            bits.reserved_bits = 2;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_qos_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.qos = 3; // Invalid QoS
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_retain_handling_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.retain_handling = 3; // Invalid retain handling
            assert!(bits.to_options().is_err());
        }

        proptest! {
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let qos_enum = match qos {
                    0 => QoS::AtMostOnce,
                    1 => QoS::AtLeastOnce,
                    2 => QoS::ExactlyOnce,
                    _ => unreachable!(),
                };

                let retain_handling_enum = match retain_handling {
                    0 => RetainHandling::SendAtSubscribe,
                    1 => RetainHandling::SendAtSubscribeIfNew,
                    2 => RetainHandling::DoNotSend,
                    _ => unreachable!(),
                };

                let options = SubscriptionOptions {
                    qos: qos_enum,
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_enum,
                };

                // Both encoding methods should produce identical results
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                // Both decoding methods should produce identical results
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                // Should be able to convert to high-level options
                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        assert_eq!(encoded, 0x1D); // QoS 1 + No Local + RAP + RH 1

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    #[test]
    fn test_subscribe_v5_options_round_trip() {
        // Every v5.0 option (not just QoS) must survive an encode/decode cycle.
        let options = SubscriptionOptions {
            qos: QoS::ExactlyOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::DoNotSend,
        };
        let packet = SubscribePacket::new(55)
            .add_filter_with_options(TopicFilter::with_options("x/+/y", options));

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();

        assert_eq!(decoded.packet_id, 55);
        assert_eq!(decoded.filters, packet.filters);
        assert_eq!(decoded.filters[0].options, options);
    }

    #[test]
    fn test_subscribe_v311_encode_decode() {
        let packet = SubscribePacket::new_v311(321)
            .add_filter("a/b", QoS::AtLeastOnce)
            .add_filter("c/#", QoS::ExactlyOnce);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.flags, 0x02);

        let decoded =
            SubscribePacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();
        assert_eq!(decoded.packet_id, 321);
        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.filters, packet.filters);
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2); // Wrong flags
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}