1use super::{
2 BufMut, BytesMut, Error, FixedHeader, PropertyType, QoS, len_len, length, property, qos,
3 read_mqtt_string, read_u8, read_u16, write_mqtt_string, write_remaining_length,
4};
5use bytes::{Buf, Bytes};
6
7#[derive(Clone, Debug, PartialEq, Eq, Default)]
9pub struct Subscribe {
10 pub pkid: u16,
11 pub filters: Vec<Filter>,
12 pub properties: Option<SubscribeProperties>,
13}
14
15impl Subscribe {
16 #[must_use]
17 pub fn new(filter: Filter, properties: Option<SubscribeProperties>) -> Self {
18 Self {
19 filters: vec![filter],
20 properties,
21 ..Default::default()
22 }
23 }
24
25 pub fn new_many<F>(filters: F, properties: Option<SubscribeProperties>) -> Self
26 where
27 F: IntoIterator<Item = Filter>,
28 {
29 Self {
30 filters: filters.into_iter().collect(),
31 properties,
32 ..Default::default()
33 }
34 }
35
36 #[must_use]
37 pub fn size(&self) -> usize {
38 let len = self.len();
39 let remaining_len_size = len_len(len);
40
41 1 + remaining_len_size + len
42 }
43
44 fn len(&self) -> usize {
45 let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len());
46
47 if let Some(p) = &self.properties {
48 let properties_len = p.len();
49 let properties_len_len = len_len(properties_len);
50 len += properties_len_len + properties_len;
51 } else {
52 len += 1;
54 }
55
56 len
57 }
58
59 pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
60 let variable_header_index = fixed_header.header_len;
61 bytes.advance(variable_header_index);
62
63 let pkid = read_u16(&mut bytes)?;
64 let properties = SubscribeProperties::read(&mut bytes)?;
65
66 let filters = Filter::read(&mut bytes)?;
68
69 match filters.len() {
70 0 => Err(Error::EmptySubscription),
71 _ => Ok(Self {
72 pkid,
73 filters,
74 properties,
75 }),
76 }
77 }
78
79 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
80 buffer.put_u8(0x82);
82
83 let remaining_len = self.len();
85 let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?;
86
87 buffer.put_u16(self.pkid);
89
90 if let Some(p) = &self.properties {
91 p.write(buffer)?;
92 } else {
93 write_remaining_length(buffer, 0)?;
94 }
95
96 for f in &self.filters {
98 f.write(buffer);
99 }
100
101 Ok(1 + remaining_len_bytes + remaining_len)
102 }
103}
104
105#[derive(Clone, Debug, PartialEq, Eq, Default)]
107pub struct Filter {
108 pub path: String,
109 pub qos: QoS,
110 pub nolocal: bool,
111 pub preserve_retain: bool,
112 pub retain_forward_rule: RetainForwardRule,
113}
114
115impl Filter {
116 pub fn new<T: Into<String>>(topic: T, qos: QoS) -> Self {
117 Self {
118 path: topic.into(),
119 qos,
120 ..Default::default()
121 }
122 }
123
124 const fn len(&self) -> usize {
125 2 + self.path.len() + 1
127 }
128
129 pub fn read(bytes: &mut Bytes) -> Result<Vec<Self>, Error> {
130 let mut filters = Vec::new();
132
133 while bytes.has_remaining() {
134 let path = read_mqtt_string(bytes)?;
135 let options = read_u8(bytes)?;
136 let requested_qos = options & 0b0000_0011;
137
138 let nolocal = (options >> 2) & 0b0000_0001;
139 let nolocal = nolocal != 0;
140
141 let preserve_retain = (options >> 3) & 0b0000_0001;
142 let preserve_retain = preserve_retain != 0;
143
144 let retain_forward_rule = (options >> 4) & 0b0000_0011;
145 let retain_forward_rule = match retain_forward_rule {
146 0 => RetainForwardRule::OnEverySubscribe,
147 1 => RetainForwardRule::OnNewSubscribe,
148 2 => RetainForwardRule::Never,
149 r => return Err(Error::InvalidRetainForwardRule(r)),
150 };
151
152 filters.push(Self {
153 path,
154 qos: qos(requested_qos).ok_or(Error::InvalidQoS(requested_qos))?,
155 nolocal,
156 preserve_retain,
157 retain_forward_rule,
158 });
159 }
160
161 Ok(filters)
162 }
163
164 pub fn write(&self, buffer: &mut BytesMut) {
165 let mut options = 0;
166 options |= self.qos as u8;
167
168 if self.nolocal {
169 options |= 0b0000_0100;
170 }
171
172 if self.preserve_retain {
173 options |= 0b0000_1000;
174 }
175
176 options |= match self.retain_forward_rule {
177 RetainForwardRule::OnEverySubscribe => 0b0000_0000,
178 RetainForwardRule::OnNewSubscribe => 0b0001_0000,
179 RetainForwardRule::Never => 0b0010_0000,
180 };
181
182 write_mqtt_string(buffer, self.path.as_str());
183 buffer.put_u8(options);
184 }
185}
186
/// When retained messages are forwarded for a subscription (retain handling).
///
/// Encoded in bits 4-5 of the subscription options byte as 0, 1 and 2
/// respectively; the value 3 is invalid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RetainForwardRule {
    /// Send retained messages every time the subscription is made.
    #[default]
    OnEverySubscribe,
    /// Send retained messages only if the subscription did not already exist.
    OnNewSubscribe,
    /// Never send retained messages for this subscription.
    Never,
}
194
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct SubscribeProperties {
197 pub id: Option<usize>,
198 pub user_properties: Vec<(String, String)>,
199}
200
201impl SubscribeProperties {
202 fn len(&self) -> usize {
203 let mut len = 0;
204
205 if let Some(id) = &self.id {
206 len += 1 + len_len(*id);
207 }
208
209 for (key, value) in &self.user_properties {
210 len += 1 + 2 + key.len() + 2 + value.len();
211 }
212
213 len
214 }
215
216 pub fn read(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
217 let mut id = None;
218 let mut user_properties = Vec::new();
219
220 let (properties_len_len, properties_len) = length(bytes.iter())?;
221 bytes.advance(properties_len_len);
222
223 if properties_len == 0 {
224 return Ok(None);
225 }
226
227 let mut cursor = 0;
228 while cursor < properties_len {
230 let prop = read_u8(bytes)?;
231 cursor += 1;
232
233 match property(prop)? {
234 PropertyType::SubscriptionIdentifier => {
235 let (id_len, sub_id) = length(bytes.iter())?;
236 if sub_id == 0 {
237 return Err(Error::ProtocolError);
238 }
239 cursor += id_len;
240 bytes.advance(id_len);
241 id = Some(sub_id);
242 }
243 PropertyType::UserProperty => {
244 let key = read_mqtt_string(bytes)?;
245 let value = read_mqtt_string(bytes)?;
246 cursor += 2 + key.len() + 2 + value.len();
247 user_properties.push((key, value));
248 }
249 _ => return Err(Error::InvalidPropertyType(prop)),
250 }
251 }
252
253 Ok(Some(Self {
254 id,
255 user_properties,
256 }))
257 }
258
259 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
260 let len = self.len();
261 write_remaining_length(buffer, len)?;
262
263 if let Some(id) = &self.id {
264 if *id == 0 {
265 return Err(Error::ProtocolError);
266 }
267 buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
268 write_remaining_length(buffer, *id)?;
269 }
270
271 for (key, value) in &self.user_properties {
272 buffer.put_u8(PropertyType::UserProperty as u8);
273 write_mqtt_string(buffer, key);
274 write_mqtt_string(buffer, value);
275 }
276
277 Ok(())
278 }
279}
280
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::{Bytes, BytesMut};
    use pretty_assertions::assert_eq;

    /// `size()`, the count returned by `write()`, and the bytes actually
    /// written must all agree.
    #[test]
    fn length_calculation() {
        let properties = SubscribeProperties {
            id: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
        };
        let packet = Subscribe::new(
            Filter::new("hello/world", QoS::AtMostOnce),
            Some(properties),
        );
        let predicted = packet.size();

        let mut out = BytesMut::new();
        let reported = packet.write(&mut out).unwrap();
        let actual = out.len();

        assert_eq!(reported, actual);
        assert_eq!(predicted, actual);
    }

    /// A Subscription Identifier of zero on the wire is a protocol error.
    #[test]
    fn read_rejects_subscription_identifier_zero() {
        // Property length 2, then Subscription Identifier (0x0B) with value 0.
        let mut input = Bytes::from_static(&[0x02, 0x0B, 0x00]);

        let outcome = SubscribeProperties::read(&mut input);

        assert!(matches!(outcome, Err(Error::ProtocolError)));
    }

    /// Encoding a Subscription Identifier of zero is refused.
    #[test]
    fn write_rejects_subscription_identifier_zero() {
        let mut out = BytesMut::new();

        let outcome = SubscribeProperties {
            id: Some(0),
            user_properties: Vec::new(),
        }
        .write(&mut out);

        assert!(matches!(outcome, Err(Error::ProtocolError)));
    }

    /// Both a Subscription Identifier and a User Property are decoded from one block.
    #[test]
    fn read_subscription_identifier_and_user_property_parses_both() {
        let mut input = Bytes::from_static(&[
            0x09, // property length
            0x0B, 0x01, // Subscription Identifier = 1
            0x26, // User Property
            0x00, 0x01, b'k', // key "k"
            0x00, 0x01, b'v', // value "v"
        ]);

        let parsed = SubscribeProperties::read(&mut input)
            .unwrap()
            .expect("properties should be present");

        let expected_user = vec![("k".to_owned(), "v".to_owned())];
        assert_eq!(parsed.id, Some(1));
        assert_eq!(parsed.user_properties, expected_user);
    }
}