// mqtt5_protocol/packet/subscribe_options.rs

1use crate::error::{MqttError, Result};
2use crate::prelude::{format, ToString};
3use crate::QoS;
4use bebytes::BeBytes;
5
/// Retain Handling option of a subscription, stored in bits 4-5 of the
/// subscription options byte.
///
/// The discriminants are the values carried on the wire; the value 3 has no
/// variant and is rejected when decoding.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetainHandling {
    /// `0b00`: send retained messages when the subscription is made.
    SendAtSubscribe = 0b00,
    /// `0b01`: send retained messages at subscribe time only if the
    /// subscription is new.
    SendAtSubscribeIfNew = 0b01,
    /// `0b10`: do not send retained messages at subscribe time.
    DoNotSend = 0b10,
}
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
15pub struct SubscriptionOptionsBits {
16    #[bits(2)]
17    pub reserved_bits: u8,
18    #[bits(2)]
19    pub retain_handling: u8,
20    #[bits(1)]
21    pub retain_as_published: u8,
22    #[bits(1)]
23    pub no_local: u8,
24    #[bits(2)]
25    pub qos: u8,
26}
27
28impl SubscriptionOptionsBits {
29    #[must_use]
30    pub fn from_options(options: &SubscriptionOptions) -> Self {
31        Self {
32            reserved_bits: 0,
33            retain_handling: options.retain_handling as u8,
34            retain_as_published: u8::from(options.retain_as_published),
35            no_local: u8::from(options.no_local),
36            qos: options.qos as u8,
37        }
38    }
39
40    /// # Errors
41    /// Returns an error if reserved bits are set, or if `QoS` or retain handling values are invalid.
42    pub fn to_options(&self) -> Result<SubscriptionOptions> {
43        if self.reserved_bits != 0 {
44            return Err(MqttError::MalformedPacket(
45                "Reserved bits in subscription options must be 0".to_string(),
46            ));
47        }
48
49        let qos = match self.qos {
50            0 => QoS::AtMostOnce,
51            1 => QoS::AtLeastOnce,
52            2 => QoS::ExactlyOnce,
53            _ => {
54                return Err(MqttError::MalformedPacket(format!(
55                    "Invalid QoS value in subscription options: {}",
56                    self.qos
57                )))
58            }
59        };
60
61        let retain_handling = match self.retain_handling {
62            0 => RetainHandling::SendAtSubscribe,
63            1 => RetainHandling::SendAtSubscribeIfNew,
64            2 => RetainHandling::DoNotSend,
65            _ => {
66                return Err(MqttError::MalformedPacket(format!(
67                    "Invalid retain handling value: {}",
68                    self.retain_handling
69                )))
70            }
71        };
72
73        Ok(SubscriptionOptions {
74            qos,
75            no_local: self.no_local != 0,
76            retain_as_published: self.retain_as_published != 0,
77            retain_handling,
78        })
79    }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq)]
83pub struct SubscriptionOptions {
84    pub qos: QoS,
85    pub no_local: bool,
86    pub retain_as_published: bool,
87    pub retain_handling: RetainHandling,
88}
89
90impl Default for SubscriptionOptions {
91    fn default() -> Self {
92        Self {
93            qos: QoS::AtMostOnce,
94            no_local: false,
95            retain_as_published: false,
96            retain_handling: RetainHandling::SendAtSubscribe,
97        }
98    }
99}
100
101impl SubscriptionOptions {
102    #[must_use]
103    pub fn new(qos: QoS) -> Self {
104        Self {
105            qos,
106            ..Default::default()
107        }
108    }
109
110    #[must_use]
111    pub fn with_qos(mut self, qos: QoS) -> Self {
112        self.qos = qos;
113        self
114    }
115
116    #[must_use]
117    pub fn encode(&self) -> u8 {
118        let mut byte = self.qos as u8;
119
120        if self.no_local {
121            byte |= 0x04;
122        }
123
124        if self.retain_as_published {
125            byte |= 0x08;
126        }
127
128        byte |= (self.retain_handling as u8) << 4;
129
130        byte
131    }
132
133    #[must_use]
134    pub fn encode_with_bebytes(&self) -> u8 {
135        let bits = SubscriptionOptionsBits::from_options(self);
136        bits.to_be_bytes()[0]
137    }
138
139    /// # Errors
140    /// Returns an error if the `QoS` value or retain handling is invalid, or reserved bits are set.
141    pub fn decode(byte: u8) -> Result<Self> {
142        let qos_val = byte & crate::constants::subscription::QOS_MASK;
143        let qos = match qos_val {
144            0 => QoS::AtMostOnce,
145            1 => QoS::AtLeastOnce,
146            2 => QoS::ExactlyOnce,
147            _ => {
148                return Err(MqttError::MalformedPacket(format!(
149                    "Invalid QoS value in subscription options: {qos_val}"
150                )))
151            }
152        };
153
154        let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
155        let retain_as_published =
156            (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
157
158        let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
159            & crate::constants::subscription::QOS_MASK;
160        let retain_handling = match retain_handling_val {
161            0 => RetainHandling::SendAtSubscribe,
162            1 => RetainHandling::SendAtSubscribeIfNew,
163            2 => RetainHandling::DoNotSend,
164            _ => {
165                return Err(MqttError::MalformedPacket(format!(
166                    "Invalid retain handling value: {retain_handling_val}"
167                )))
168            }
169        };
170
171        if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
172            return Err(MqttError::MalformedPacket(
173                "Reserved bits in subscription options must be 0".to_string(),
174            ));
175        }
176
177        Ok(Self {
178            qos,
179            no_local,
180            retain_as_published,
181            retain_handling,
182        })
183    }
184
185    /// # Errors
186    /// Returns an error if the `QoS` value or retain handling is invalid, or reserved bits are set.
187    pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
188        let (bits, _consumed) =
189            SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
190                MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
191            })?;
192
193        bits.to_options()
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_bebytes_vs_manual_encoding_identical() {
        let test_cases = vec![
            SubscriptionOptions::default(),
            SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: true,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            },
            SubscriptionOptions {
                qos: QoS::ExactlyOnce,
                no_local: false,
                retain_as_published: true,
                retain_handling: RetainHandling::DoNotSend,
            },
        ];

        for options in test_cases {
            let manual_encoded = options.encode();
            let bebytes_encoded = options.encode_with_bebytes();

            assert_eq!(
                manual_encoded, bebytes_encoded,
                "Manual and bebytes encoding should be identical for options: {options:?}"
            );

            let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
            let bebytes_decoded =
                SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

            assert_eq!(manual_decoded, bebytes_decoded);
            assert_eq!(manual_decoded, options);
        }
    }

    #[test]
    fn test_subscription_options_bits_round_trip() {
        use bebytes::BeBytes;
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let bits = SubscriptionOptionsBits::from_options(&options);
        let bytes = bits.to_be_bytes();
        assert_eq!(bytes.len(), 1);

        let (decoded_bits, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(consumed, 1);
        assert_eq!(decoded_bits, bits);

        let decoded_options = decoded_bits.to_options().unwrap();
        assert_eq!(decoded_options, options);
    }

    #[test]
    fn test_reserved_bits_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

        // Every non-zero value of the two-bit reserved field must be rejected.
        for reserved in 1..=3 {
            bits.reserved_bits = reserved;
            assert!(bits.to_options().is_err(), "reserved_bits = {reserved}");
        }
    }

    #[test]
    fn test_invalid_qos_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
        bits.qos = 3;
        assert!(bits.to_options().is_err());
    }

    #[test]
    fn test_invalid_retain_handling_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
        bits.retain_handling = 3;
        assert!(bits.to_options().is_err());
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        assert_eq!(encoded, 0x1D);

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    #[test]
    fn test_manual_decode_rejects_invalid_fields() {
        // QoS field = 3.
        assert!(SubscriptionOptions::decode(0x03).is_err());
        // Retain Handling field = 3.
        assert!(SubscriptionOptions::decode(0x30).is_err());
        // Each reserved bit on its own, and both together.
        for byte in [0x40, 0x80, 0xC0] {
            assert!(
                SubscriptionOptions::decode(byte).is_err(),
                "byte = {byte:#04x}"
            );
        }
    }

    #[test]
    fn test_all_bytes_decoders_agree_and_round_trip() {
        let mut valid = 0;
        for byte in 0..=u8::MAX {
            let manual = SubscriptionOptions::decode(byte).ok();
            let via_bebytes = SubscriptionOptions::decode_with_bebytes(byte).ok();
            assert_eq!(manual, via_bebytes, "decoders disagree on byte {byte:#04x}");

            if let Some(options) = manual {
                valid += 1;
                assert_eq!(options.encode(), byte, "encode round trip for {byte:#04x}");
                assert_eq!(
                    options.encode_with_bebytes(),
                    byte,
                    "bebytes round trip for {byte:#04x}"
                );
            }
        }
        // 3 QoS values x 2 No Local x 2 Retain As Published x 3 Retain Handling.
        assert_eq!(valid, 36);
    }

    proptest! {
        #[test]
        fn prop_manual_vs_bebytes_encoding_consistency(
            qos in 0u8..=2,
            no_local: bool,
            retain_as_published: bool,
            retain_handling in 0u8..=2
        ) {
            let qos_enum = match qos {
                0 => QoS::AtMostOnce,
                1 => QoS::AtLeastOnce,
                2 => QoS::ExactlyOnce,
                _ => unreachable!(),
            };

            let retain_handling_enum = match retain_handling {
                0 => RetainHandling::SendAtSubscribe,
                1 => RetainHandling::SendAtSubscribeIfNew,
                2 => RetainHandling::DoNotSend,
                _ => unreachable!(),
            };

            let options = SubscriptionOptions {
                qos: qos_enum,
                no_local,
                retain_as_published,
                retain_handling: retain_handling_enum,
            };

            let manual_encoded = options.encode();
            let bebytes_encoded = options.encode_with_bebytes();
            prop_assert_eq!(manual_encoded, bebytes_encoded);

            let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
            let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
            prop_assert_eq!(manual_decoded, bebytes_decoded);
            prop_assert_eq!(manual_decoded, options);
        }

        #[test]
        fn prop_bebytes_bit_field_round_trip(
            qos in 0u8..=2,
            no_local: bool,
            retain_as_published: bool,
            retain_handling in 0u8..=2
        ) {
            use bebytes::BeBytes;
            let bits = SubscriptionOptionsBits {
                reserved_bits: 0,
                retain_handling,
                retain_as_published: u8::from(retain_as_published),
                no_local: u8::from(no_local),
                qos,
            };

            let bytes = bits.to_be_bytes();
            let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(consumed, 1);
            prop_assert_eq!(decoded, bits);

            let options = decoded.to_options().unwrap();
            prop_assert_eq!(options.qos as u8, qos);
            prop_assert_eq!(options.no_local, no_local);
            prop_assert_eq!(options.retain_as_published, retain_as_published);
            prop_assert_eq!(options.retain_handling as u8, retain_handling);
        }
    }
}