// mqtt5_protocol/packet/subscribe.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::protocol::v5::properties::Properties;
5use crate::QoS;
6use bebytes::BeBytes;
7use bytes::{Buf, BufMut};
8
9/// Subscription options (v5.0)
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11pub struct SubscriptionOptions {
12    /// Maximum `QoS` level the client will accept
13    pub qos: QoS,
14    /// No Local option - if true, Application Messages MUST NOT be forwarded to this connection
15    pub no_local: bool,
16    /// Retain As Published - if true, keep the RETAIN flag as published
17    pub retain_as_published: bool,
18    /// Retain Handling option
19    pub retain_handling: RetainHandling,
20}
21
/// Retain Handling subscription option: decides when retained messages are
/// delivered for a subscription.
///
/// The discriminants are the 2-bit values stored in bits 5-4 of the
/// subscription-options byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RetainHandling {
    /// Value 0: deliver retained messages at subscribe time.
    SendAtSubscribe = 0,
    /// Value 1: deliver retained messages at subscribe time only when the
    /// subscription did not already exist.
    SendAtSubscribeIfNew = 1,
    /// Value 2: never deliver retained messages at subscribe time.
    DoNotSend = 2,
}
33
34/// Subscription options using bebytes for bit field operations
35/// This demonstrates the hybrid approach for complex packet variable headers
36/// Bit fields are ordered from MSB to LSB (bits 7-0)
37#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
38pub struct SubscriptionOptionsBits {
39    /// Reserved bits (bits 7-6) - must be 0
40    #[bits(2)]
41    pub reserved_bits: u8,
42    /// Retain Handling (bits 5-4)
43    #[bits(2)]
44    pub retain_handling: u8,
45    /// Retain As Published flag (bit 3)
46    #[bits(1)]
47    pub retain_as_published: u8,
48    /// No Local flag (bit 2)
49    #[bits(1)]
50    pub no_local: u8,
51    /// `QoS` level (bits 1-0)
52    #[bits(2)]
53    pub qos: u8,
54}
55
56impl SubscriptionOptionsBits {
57    /// Creates subscription options bits from high-level `SubscriptionOptions`
58    /// Bebytes handles bit field layout, Rust handles type safety and validation
59    #[must_use]
60    pub fn from_options(options: &SubscriptionOptions) -> Self {
61        Self {
62            reserved_bits: 0,
63            retain_handling: options.retain_handling as u8,
64            retain_as_published: u8::from(options.retain_as_published),
65            no_local: u8::from(options.no_local),
66            qos: options.qos as u8,
67        }
68    }
69
70    /// Converts bebytes bit fields back to high-level `SubscriptionOptions`
71    /// Bebytes provides the bits, Rust handles validation and type conversion
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if reserved bits are set, or if `QoS` or retain handling values are invalid
76    pub fn to_options(&self) -> Result<SubscriptionOptions> {
77        // Validate reserved bits are zero
78        if self.reserved_bits != 0 {
79            return Err(MqttError::MalformedPacket(
80                "Reserved bits in subscription options must be 0".to_string(),
81            ));
82        }
83
84        // Validate and convert QoS
85        let qos = match self.qos {
86            0 => QoS::AtMostOnce,
87            1 => QoS::AtLeastOnce,
88            2 => QoS::ExactlyOnce,
89            _ => {
90                return Err(MqttError::MalformedPacket(format!(
91                    "Invalid QoS value in subscription options: {}",
92                    self.qos
93                )))
94            }
95        };
96
97        // Validate and convert retain handling
98        let retain_handling = match self.retain_handling {
99            0 => RetainHandling::SendAtSubscribe,
100            1 => RetainHandling::SendAtSubscribeIfNew,
101            2 => RetainHandling::DoNotSend,
102            _ => {
103                return Err(MqttError::MalformedPacket(format!(
104                    "Invalid retain handling value: {}",
105                    self.retain_handling
106                )))
107            }
108        };
109
110        Ok(SubscriptionOptions {
111            qos,
112            no_local: self.no_local != 0,
113            retain_as_published: self.retain_as_published != 0,
114            retain_handling,
115        })
116    }
117}
118
119impl Default for SubscriptionOptions {
120    fn default() -> Self {
121        Self {
122            qos: QoS::AtMostOnce,
123            no_local: false,
124            retain_as_published: false,
125            retain_handling: RetainHandling::SendAtSubscribe,
126        }
127    }
128}
129
130impl SubscriptionOptions {
131    /// Creates subscription options with the specified `QoS`
132    #[must_use]
133    pub fn new(qos: QoS) -> Self {
134        Self {
135            qos,
136            ..Default::default()
137        }
138    }
139
140    /// Sets the `QoS` level
141    #[must_use]
142    pub fn with_qos(mut self, qos: QoS) -> Self {
143        self.qos = qos;
144        self
145    }
146
147    /// Encodes subscription options as a byte (v5.0)
148    /// Original manual implementation for comparison
149    #[must_use]
150    pub fn encode(&self) -> u8 {
151        let mut byte = self.qos as u8;
152
153        if self.no_local {
154            byte |= 0x04;
155        }
156
157        if self.retain_as_published {
158            byte |= 0x08;
159        }
160
161        byte |= (self.retain_handling as u8) << 4;
162
163        byte
164    }
165
166    /// Encodes subscription options using bebytes (hybrid approach)
167    /// Bebytes handles bit field operations, Rust handles type safety
168    #[must_use]
169    pub fn encode_with_bebytes(&self) -> u8 {
170        let bits = SubscriptionOptionsBits::from_options(self);
171        bits.to_be_bytes()[0]
172    }
173
174    /// Decodes subscription options from a byte (v5.0)
175    /// Original manual implementation for comparison
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if the `QoS` value is invalid
180    pub fn decode(byte: u8) -> Result<Self> {
181        let qos_val = byte & crate::constants::subscription::QOS_MASK;
182        let qos = match qos_val {
183            0 => QoS::AtMostOnce,
184            1 => QoS::AtLeastOnce,
185            2 => QoS::ExactlyOnce,
186            _ => {
187                return Err(MqttError::MalformedPacket(format!(
188                    "Invalid QoS value in subscription options: {qos_val}"
189                )))
190            }
191        };
192
193        let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
194        let retain_as_published =
195            (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
196
197        let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
198            & crate::constants::subscription::QOS_MASK;
199        let retain_handling = match retain_handling_val {
200            0 => RetainHandling::SendAtSubscribe,
201            1 => RetainHandling::SendAtSubscribeIfNew,
202            2 => RetainHandling::DoNotSend,
203            _ => {
204                return Err(MqttError::MalformedPacket(format!(
205                    "Invalid retain handling value: {retain_handling_val}"
206                )))
207            }
208        };
209
210        // Check reserved bits
211        if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
212            return Err(MqttError::MalformedPacket(
213                "Reserved bits in subscription options must be 0".to_string(),
214            ));
215        }
216
217        Ok(Self {
218            qos,
219            no_local,
220            retain_as_published,
221            retain_handling,
222        })
223    }
224
225    /// Decodes subscription options using bebytes (hybrid approach)\
226    /// Bebytes handles bit field extraction, Rust handles validation and type conversion
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the `QoS` value or retain handling is invalid, or reserved bits are set
231    pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
232        let (bits, _consumed) =
233            SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
234                MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
235            })?;
236
237        bits.to_options()
238    }
239}
240
241/// Topic filter with subscription options
242#[derive(Debug, Clone, PartialEq, Eq)]
243pub struct TopicFilter {
244    /// Topic filter string (may contain wildcards)
245    pub filter: String,
246    /// Subscription options
247    pub options: SubscriptionOptions,
248}
249
250impl TopicFilter {
251    /// Creates a new topic filter with the specified `QoS`
252    #[must_use]
253    pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
254        Self {
255            filter: filter.into(),
256            options: SubscriptionOptions::new(qos),
257        }
258    }
259
260    /// Creates a new topic filter with custom options
261    #[must_use]
262    pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
263        Self {
264            filter: filter.into(),
265            options,
266        }
267    }
268}
269
270/// MQTT SUBSCRIBE packet
271#[derive(Debug, Clone)]
272pub struct SubscribePacket {
273    /// Packet identifier
274    pub packet_id: u16,
275    /// Topic filters to subscribe to
276    pub filters: Vec<TopicFilter>,
277    /// SUBSCRIBE properties (v5.0 only)
278    pub properties: Properties,
279}
280
281impl SubscribePacket {
282    /// Creates a new SUBSCRIBE packet
283    #[must_use]
284    pub fn new(packet_id: u16) -> Self {
285        Self {
286            packet_id,
287            filters: Vec::new(),
288            properties: Properties::default(),
289        }
290    }
291
292    /// Adds a topic filter
293    #[must_use]
294    pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
295        self.filters.push(TopicFilter::new(filter, qos));
296        self
297    }
298
299    /// Adds a topic filter with options
300    #[must_use]
301    pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
302        self.filters.push(filter);
303        self
304    }
305
306    /// Sets the subscription identifier
307    #[must_use]
308    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
309        self.properties.set_subscription_identifier(id);
310        self
311    }
312
313    /// Adds a user property
314    #[must_use]
315    pub fn with_user_property(mut self, key: String, value: String) -> Self {
316        self.properties.add_user_property(key, value);
317        self
318    }
319}
320
321impl MqttPacket for SubscribePacket {
322    fn packet_type(&self) -> PacketType {
323        PacketType::Subscribe
324    }
325
326    fn flags(&self) -> u8 {
327        0x02 // SUBSCRIBE must have flags = 0x02
328    }
329
330    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
331        // Variable header
332        buf.put_u16(self.packet_id);
333
334        // Properties (v5.0)
335        self.properties.encode(buf)?;
336
337        // Payload - topic filters
338        if self.filters.is_empty() {
339            return Err(MqttError::MalformedPacket(
340                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
341            ));
342        }
343
344        for filter in &self.filters {
345            encode_string(buf, &filter.filter)?;
346            buf.put_u8(filter.options.encode());
347        }
348
349        Ok(())
350    }
351
352    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
353        // Validate flags
354        if fixed_header.flags != 0x02 {
355            return Err(MqttError::MalformedPacket(format!(
356                "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
357                fixed_header.flags
358            )));
359        }
360
361        // Packet identifier
362        if buf.remaining() < 2 {
363            return Err(MqttError::MalformedPacket(
364                "SUBSCRIBE missing packet identifier".to_string(),
365            ));
366        }
367        let packet_id = buf.get_u16();
368
369        // Properties (v5.0)
370        let properties = Properties::decode(buf)?;
371
372        // Payload - topic filters
373        let mut filters = Vec::new();
374
375        if !buf.has_remaining() {
376            return Err(MqttError::MalformedPacket(
377                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
378            ));
379        }
380
381        while buf.has_remaining() {
382            let filter_str = decode_string(buf)?;
383
384            if !buf.has_remaining() {
385                return Err(MqttError::MalformedPacket(
386                    "Missing subscription options for topic filter".to_string(),
387                ));
388            }
389
390            let options_byte = buf.get_u8();
391            let options = SubscriptionOptions::decode(options_byte)?;
392
393            filters.push(TopicFilter {
394                filter: filter_str,
395                options,
396            });
397        }
398
399        Ok(Self {
400            packet_id,
401            filters,
402            properties,
403        })
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bebytes::BeBytes;
    use bytes::BytesMut;

    // Cross-checks the bebytes-based ("hybrid") codec against the hand-written one.
    // No separate `#[cfg(test)]` is needed: the enclosing module is already test-only.
    mod hybrid_approach_tests {
        use super::*;
        use proptest::prelude::*;

        #[test]
        fn test_bebytes_vs_manual_encoding_identical() {
            // Test that bebytes and manual implementations produce identical results
            let test_cases = [
                SubscriptionOptions::default(),
                SubscriptionOptions {
                    qos: QoS::AtLeastOnce,
                    no_local: true,
                    retain_as_published: true,
                    retain_handling: RetainHandling::SendAtSubscribeIfNew,
                },
                SubscriptionOptions {
                    qos: QoS::ExactlyOnce,
                    no_local: false,
                    retain_as_published: true,
                    retain_handling: RetainHandling::DoNotSend,
                },
            ];

            for options in test_cases {
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();

                assert_eq!(
                    manual_encoded, bebytes_encoded,
                    "Manual and bebytes encoding should be identical for options: {options:?}"
                );

                // Also verify decoding produces same results
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded =
                    SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

                assert_eq!(manual_decoded, bebytes_decoded);
                assert_eq!(manual_decoded, options);
            }
        }

        #[test]
        fn test_subscription_options_bits_round_trip() {
            let options = SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            };

            let bits = SubscriptionOptionsBits::from_options(&options);
            let bytes = bits.to_be_bytes();
            assert_eq!(bytes.len(), 1);

            let (decoded_bits, consumed) =
                SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(consumed, 1);
            assert_eq!(decoded_bits, bits);

            let decoded_options = decoded_bits.to_options().unwrap();
            assert_eq!(decoded_options, options);
        }

        #[test]
        fn test_reserved_bits_validation() {
            // Test that reserved bits being set causes validation errors
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

            // Set reserved bits
            bits.reserved_bits = 1;
            assert!(bits.to_options().is_err());

            // Set different reserved bit pattern
            bits.reserved_bits = 2;
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_qos_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.qos = 3; // Invalid QoS
            assert!(bits.to_options().is_err());
        }

        #[test]
        fn test_invalid_retain_handling_validation() {
            let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
            bits.retain_handling = 3; // Invalid retain handling
            assert!(bits.to_options().is_err());
        }

        proptest! {
            #[test]
            fn prop_manual_vs_bebytes_encoding_consistency(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let qos_enum = match qos {
                    0 => QoS::AtMostOnce,
                    1 => QoS::AtLeastOnce,
                    2 => QoS::ExactlyOnce,
                    _ => unreachable!(),
                };

                let retain_handling_enum = match retain_handling {
                    0 => RetainHandling::SendAtSubscribe,
                    1 => RetainHandling::SendAtSubscribeIfNew,
                    2 => RetainHandling::DoNotSend,
                    _ => unreachable!(),
                };

                let options = SubscriptionOptions {
                    qos: qos_enum,
                    no_local,
                    retain_as_published,
                    retain_handling: retain_handling_enum,
                };

                // Both encoding methods should produce identical results
                let manual_encoded = options.encode();
                let bebytes_encoded = options.encode_with_bebytes();
                prop_assert_eq!(manual_encoded, bebytes_encoded);

                // Both decoding methods should produce identical results
                let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
                let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
                prop_assert_eq!(manual_decoded, bebytes_decoded);
                prop_assert_eq!(manual_decoded, options);
            }

            #[test]
            fn prop_bebytes_bit_field_round_trip(
                qos in 0u8..=2,
                no_local: bool,
                retain_as_published: bool,
                retain_handling in 0u8..=2
            ) {
                let bits = SubscriptionOptionsBits {
                    reserved_bits: 0,
                    retain_handling,
                    retain_as_published: u8::from(retain_as_published),
                    no_local: u8::from(no_local),
                    qos,
                };

                let bytes = bits.to_be_bytes();
                let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(consumed, 1);
                prop_assert_eq!(decoded, bits);

                // Should be able to convert to high-level options
                let options = decoded.to_options().unwrap();
                prop_assert_eq!(options.qos as u8, qos);
                prop_assert_eq!(options.no_local, no_local);
                prop_assert_eq!(options.retain_as_published, retain_as_published);
                prop_assert_eq!(options.retain_handling as u8, retain_handling);
            }
        }
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        assert_eq!(encoded, 0x1D); // QoS 1 + No Local + RAP + RH 1

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    #[test]
    fn test_subscription_options_decode_rejects_reserved_bits() {
        // Bits 7-6 are reserved; each pattern must be rejected by both decoders.
        for byte in [0x40u8, 0x80, 0xC0] {
            assert!(SubscriptionOptions::decode(byte).is_err(), "byte 0x{byte:02X}");
            assert!(
                SubscriptionOptions::decode_with_bebytes(byte).is_err(),
                "byte 0x{byte:02X}"
            );
        }
    }

    #[test]
    fn test_subscription_options_decode_rejects_invalid_fields() {
        // 0x03: QoS field = 3; 0x30: Retain Handling field = 3. Neither is defined.
        for byte in [0x03u8, 0x30] {
            assert!(SubscriptionOptions::decode(byte).is_err(), "byte 0x{byte:02X}");
            assert!(
                SubscriptionOptions::decode_with_bebytes(byte).is_err(),
                "byte 0x{byte:02X}"
            );
        }
    }

    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2); // Wrong flags
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_decode_missing_options_byte() {
        let mut buf = BytesMut::new();
        buf.put_u16(1); // packet identifier
        buf.put_u8(0); // property length: no properties
        buf.put_u16(3); // topic filter length
        buf.put_slice(b"a/b"); // topic filter, deliberately not followed by an options byte

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x02, 8);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_decode_no_topic_filters() {
        let mut buf = BytesMut::new();
        buf.put_u16(1); // packet identifier
        buf.put_u8(0); // property length: no properties, and no payload after it

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x02, 3);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}
672}