1use crate::error::{MqttError, Result};
2use crate::prelude::{format, ToString};
3use crate::QoS;
4use bebytes::BeBytes;
5
/// The two-bit "Retain Handling" field of an MQTT v5 subscription-options byte.
///
/// Each variant's discriminant is the numeric value written into that field,
/// so `variant as u8` yields the on-the-wire value (0, 1 or 2; 3 is invalid).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum RetainHandling {
    /// Wire value `0` (MQTT v5: send retained messages when subscribing).
    SendAtSubscribe = 0,
    /// Wire value `1` (MQTT v5: send retained messages only for a new subscription).
    SendAtSubscribeIfNew = 1,
    /// Wire value `2` (MQTT v5: do not send retained messages on subscribe).
    DoNotSend = 2,
}
13
/// Raw bit-field view of the one-byte subscription options, packed and
/// unpacked by the `BeBytes` derive.
///
/// Field declaration order determines where each field lands in the byte:
/// the first field occupies the most significant bits, matching the layout
/// produced by `SubscriptionOptions::encode`. Do not reorder the fields.
/// Values are stored unvalidated; use `to_options` to validate them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct SubscriptionOptionsBits {
    // Bits 7-6: reserved; `to_options` rejects any non-zero value.
    #[bits(2)]
    pub reserved_bits: u8,
    // Bits 5-4: `RetainHandling` discriminant (0..=2 accepted by `to_options`).
    #[bits(2)]
    pub retain_handling: u8,
    // Bit 3: Retain As Published flag (0 or 1).
    #[bits(1)]
    pub retain_as_published: u8,
    // Bit 2: No Local flag (0 or 1).
    #[bits(1)]
    pub no_local: u8,
    // Bits 1-0: QoS value (0..=2 accepted by `to_options`).
    #[bits(2)]
    pub qos: u8,
}
27
28impl SubscriptionOptionsBits {
29 #[must_use]
30 pub fn from_options(options: &SubscriptionOptions) -> Self {
31 Self {
32 reserved_bits: 0,
33 retain_handling: options.retain_handling as u8,
34 retain_as_published: u8::from(options.retain_as_published),
35 no_local: u8::from(options.no_local),
36 qos: options.qos as u8,
37 }
38 }
39
40 pub fn to_options(&self) -> Result<SubscriptionOptions> {
43 if self.reserved_bits != 0 {
44 return Err(MqttError::MalformedPacket(
45 "Reserved bits in subscription options must be 0".to_string(),
46 ));
47 }
48
49 let qos = match self.qos {
50 0 => QoS::AtMostOnce,
51 1 => QoS::AtLeastOnce,
52 2 => QoS::ExactlyOnce,
53 _ => {
54 return Err(MqttError::MalformedPacket(format!(
55 "Invalid QoS value in subscription options: {}",
56 self.qos
57 )))
58 }
59 };
60
61 let retain_handling = match self.retain_handling {
62 0 => RetainHandling::SendAtSubscribe,
63 1 => RetainHandling::SendAtSubscribeIfNew,
64 2 => RetainHandling::DoNotSend,
65 _ => {
66 return Err(MqttError::MalformedPacket(format!(
67 "Invalid retain handling value: {}",
68 self.retain_handling
69 )))
70 }
71 };
72
73 Ok(SubscriptionOptions {
74 qos,
75 no_local: self.no_local != 0,
76 retain_as_published: self.retain_as_published != 0,
77 retain_handling,
78 })
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq)]
83pub struct SubscriptionOptions {
84 pub qos: QoS,
85 pub no_local: bool,
86 pub retain_as_published: bool,
87 pub retain_handling: RetainHandling,
88}
89
90impl Default for SubscriptionOptions {
91 fn default() -> Self {
92 Self {
93 qos: QoS::AtMostOnce,
94 no_local: false,
95 retain_as_published: false,
96 retain_handling: RetainHandling::SendAtSubscribe,
97 }
98 }
99}
100
101impl SubscriptionOptions {
102 #[must_use]
103 pub fn new(qos: QoS) -> Self {
104 Self {
105 qos,
106 ..Default::default()
107 }
108 }
109
110 #[must_use]
111 pub fn with_qos(mut self, qos: QoS) -> Self {
112 self.qos = qos;
113 self
114 }
115
116 #[must_use]
117 pub fn encode(&self) -> u8 {
118 let mut byte = self.qos as u8;
119
120 if self.no_local {
121 byte |= 0x04;
122 }
123
124 if self.retain_as_published {
125 byte |= 0x08;
126 }
127
128 byte |= (self.retain_handling as u8) << 4;
129
130 byte
131 }
132
133 #[must_use]
134 pub fn encode_with_bebytes(&self) -> u8 {
135 let bits = SubscriptionOptionsBits::from_options(self);
136 bits.to_be_bytes()[0]
137 }
138
139 pub fn decode(byte: u8) -> Result<Self> {
142 let qos_val = byte & crate::constants::subscription::QOS_MASK;
143 let qos = match qos_val {
144 0 => QoS::AtMostOnce,
145 1 => QoS::AtLeastOnce,
146 2 => QoS::ExactlyOnce,
147 _ => {
148 return Err(MqttError::MalformedPacket(format!(
149 "Invalid QoS value in subscription options: {qos_val}"
150 )))
151 }
152 };
153
154 let no_local = (byte & crate::constants::subscription::NO_LOCAL_MASK) != 0;
155 let retain_as_published =
156 (byte & crate::constants::subscription::RETAIN_AS_PUBLISHED_MASK) != 0;
157
158 let retain_handling_val = (byte >> crate::constants::subscription::RETAIN_HANDLING_SHIFT)
159 & crate::constants::subscription::QOS_MASK;
160 let retain_handling = match retain_handling_val {
161 0 => RetainHandling::SendAtSubscribe,
162 1 => RetainHandling::SendAtSubscribeIfNew,
163 2 => RetainHandling::DoNotSend,
164 _ => {
165 return Err(MqttError::MalformedPacket(format!(
166 "Invalid retain handling value: {retain_handling_val}"
167 )))
168 }
169 };
170
171 if (byte & crate::constants::subscription::RESERVED_BITS_MASK) != 0 {
172 return Err(MqttError::MalformedPacket(
173 "Reserved bits in subscription options must be 0".to_string(),
174 ));
175 }
176
177 Ok(Self {
178 qos,
179 no_local,
180 retain_as_published,
181 retain_handling,
182 })
183 }
184
185 pub fn decode_with_bebytes(byte: u8) -> Result<Self> {
188 let (bits, _consumed) =
189 SubscriptionOptionsBits::try_from_be_bytes(&[byte]).map_err(|e| {
190 MqttError::MalformedPacket(format!("Invalid subscription options byte: {e}"))
191 })?;
192
193 bits.to_options()
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_bebytes_vs_manual_encoding_identical() {
        let test_cases = vec![
            SubscriptionOptions::default(),
            SubscriptionOptions {
                qos: QoS::AtLeastOnce,
                no_local: true,
                retain_as_published: true,
                retain_handling: RetainHandling::SendAtSubscribeIfNew,
            },
            SubscriptionOptions {
                qos: QoS::ExactlyOnce,
                no_local: false,
                retain_as_published: true,
                retain_handling: RetainHandling::DoNotSend,
            },
        ];

        for options in test_cases {
            let manual_encoded = options.encode();
            let bebytes_encoded = options.encode_with_bebytes();

            assert_eq!(
                manual_encoded, bebytes_encoded,
                "Manual and bebytes encoding should be identical for options: {options:?}"
            );

            let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
            let bebytes_decoded =
                SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();

            assert_eq!(manual_decoded, bebytes_decoded);
            assert_eq!(manual_decoded, options);
        }
    }

    /// Exhaustively checks every possible options byte. The manual and bebytes
    /// decoders must accept and reject exactly the same inputs and return equal
    /// values for accepted ones. Every accepted byte must also survive a
    /// decode/encode round trip unchanged.
    #[test]
    fn test_manual_and_bebytes_decode_agree_on_every_byte() {
        for byte in 0..=u8::MAX {
            let manual = SubscriptionOptions::decode(byte);
            let bebytes = SubscriptionOptions::decode_with_bebytes(byte);

            assert_eq!(
                manual.is_ok(),
                bebytes.is_ok(),
                "decoders disagree on acceptance of byte {byte:#04x}"
            );

            if let (Ok(manual), Ok(bebytes)) = (manual, bebytes) {
                assert_eq!(manual, bebytes, "decoders disagree on byte {byte:#04x}");
                assert_eq!(manual.encode(), byte, "round trip changed byte {byte:#04x}");
                assert_eq!(manual.encode_with_bebytes(), byte);
            }
        }
    }

    #[test]
    fn test_manual_decode_rejects_invalid_bytes() {
        // QoS field = 3
        assert!(SubscriptionOptions::decode(0x03).is_err());
        // Retain-handling field = 3
        assert!(SubscriptionOptions::decode(0x30).is_err());
        // Reserved bit 6 set
        assert!(SubscriptionOptions::decode(0x40).is_err());
        // Reserved bit 7 set
        assert!(SubscriptionOptions::decode(0x80).is_err());
        // Several faults at once are still rejected.
        assert!(SubscriptionOptions::decode(0xFF).is_err());
    }

    #[test]
    fn test_subscription_options_bits_round_trip() {
        use bebytes::BeBytes;
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let bits = SubscriptionOptionsBits::from_options(&options);
        let bytes = bits.to_be_bytes();
        assert_eq!(bytes.len(), 1);

        let (decoded_bits, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(consumed, 1);
        assert_eq!(decoded_bits, bits);

        let decoded_options = decoded_bits.to_options().unwrap();
        assert_eq!(decoded_options, options);
    }

    #[test]
    fn test_reserved_bits_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());

        bits.reserved_bits = 1;
        assert!(bits.to_options().is_err());

        bits.reserved_bits = 2;
        assert!(bits.to_options().is_err());
    }

    #[test]
    fn test_invalid_qos_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
        bits.qos = 3;
        assert!(bits.to_options().is_err());
    }

    #[test]
    fn test_invalid_retain_handling_validation() {
        let mut bits = SubscriptionOptionsBits::from_options(&SubscriptionOptions::default());
        bits.retain_handling = 3;
        assert!(bits.to_options().is_err());
    }

    #[test]
    fn test_subscription_options_encode_decode() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: true,
            retain_handling: RetainHandling::SendAtSubscribeIfNew,
        };

        let encoded = options.encode();
        assert_eq!(encoded, 0x1D);

        let decoded = SubscriptionOptions::decode(encoded).unwrap();
        assert_eq!(decoded, options);
    }

    proptest! {
        #[test]
        fn prop_manual_vs_bebytes_encoding_consistency(
            qos in 0u8..=2,
            no_local: bool,
            retain_as_published: bool,
            retain_handling in 0u8..=2
        ) {
            let qos_enum = match qos {
                0 => QoS::AtMostOnce,
                1 => QoS::AtLeastOnce,
                2 => QoS::ExactlyOnce,
                _ => unreachable!(),
            };

            let retain_handling_enum = match retain_handling {
                0 => RetainHandling::SendAtSubscribe,
                1 => RetainHandling::SendAtSubscribeIfNew,
                2 => RetainHandling::DoNotSend,
                _ => unreachable!(),
            };

            let options = SubscriptionOptions {
                qos: qos_enum,
                no_local,
                retain_as_published,
                retain_handling: retain_handling_enum,
            };

            let manual_encoded = options.encode();
            let bebytes_encoded = options.encode_with_bebytes();
            prop_assert_eq!(manual_encoded, bebytes_encoded);

            let manual_decoded = SubscriptionOptions::decode(manual_encoded).unwrap();
            let bebytes_decoded = SubscriptionOptions::decode_with_bebytes(bebytes_encoded).unwrap();
            prop_assert_eq!(manual_decoded, bebytes_decoded);
            prop_assert_eq!(manual_decoded, options);
        }

        #[test]
        fn prop_bebytes_bit_field_round_trip(
            qos in 0u8..=2,
            no_local: bool,
            retain_as_published: bool,
            retain_handling in 0u8..=2
        ) {
            use bebytes::BeBytes;
            let bits = SubscriptionOptionsBits {
                reserved_bits: 0,
                retain_handling,
                retain_as_published: u8::from(retain_as_published),
                no_local: u8::from(no_local),
                qos,
            };

            let bytes = bits.to_be_bytes();
            let (decoded, consumed) = SubscriptionOptionsBits::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(consumed, 1);
            prop_assert_eq!(decoded, bits);

            let options = decoded.to_options().unwrap();
            prop_assert_eq!(options.qos as u8, qos);
            prop_assert_eq!(options.no_local, no_local);
            prop_assert_eq!(options.retain_as_published, retain_as_published);
            prop_assert_eq!(options.retain_handling as u8, retain_handling);
        }
    }
}