// mqtt5_protocol/packet/subscribe.rs

1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::prelude::{format, String, ToString, Vec};
5use crate::protocol::v5::properties::Properties;
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
10pub use super::subscribe_options::{RetainHandling, SubscriptionOptions, SubscriptionOptionsBits};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct TopicFilter {
14    pub filter: String,
15    pub options: SubscriptionOptions,
16}
17
18impl TopicFilter {
19    #[must_use]
20    pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
21        Self {
22            filter: filter.into(),
23            options: SubscriptionOptions::new(qos),
24        }
25    }
26
27    #[must_use]
28    pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
29        Self {
30            filter: filter.into(),
31            options,
32        }
33    }
34}
35
36#[derive(Debug, Clone)]
37pub struct SubscribePacket {
38    pub packet_id: u16,
39    pub filters: Vec<TopicFilter>,
40    pub properties: Properties,
41    pub protocol_version: u8,
42}
43
44impl SubscribePacket {
45    #[must_use]
46    pub fn new(packet_id: u16) -> Self {
47        Self {
48            packet_id,
49            filters: Vec::new(),
50            properties: Properties::default(),
51            protocol_version: 5,
52        }
53    }
54
55    #[must_use]
56    pub fn new_v311(packet_id: u16) -> Self {
57        Self {
58            packet_id,
59            filters: Vec::new(),
60            properties: Properties::default(),
61            protocol_version: 4,
62        }
63    }
64
65    #[must_use]
66    pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
67        self.filters.push(TopicFilter::new(filter, qos));
68        self
69    }
70
71    #[must_use]
72    pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
73        self.filters.push(filter);
74        self
75    }
76
77    #[must_use]
78    pub fn with_subscription_identifier(mut self, id: u32) -> Self {
79        self.properties.set_subscription_identifier(id);
80        self
81    }
82
83    #[must_use]
84    pub fn with_user_property(mut self, key: String, value: String) -> Self {
85        self.properties.add_user_property(key, value);
86        self
87    }
88}
89
90impl MqttPacket for SubscribePacket {
91    fn packet_type(&self) -> PacketType {
92        PacketType::Subscribe
93    }
94
95    fn flags(&self) -> u8 {
96        0x02
97    }
98
99    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
100        buf.put_u16(self.packet_id);
101
102        if self.protocol_version == 5 {
103            self.properties.encode(buf)?;
104        }
105
106        if self.filters.is_empty() {
107            return Err(MqttError::MalformedPacket(
108                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
109            ));
110        }
111
112        for filter in &self.filters {
113            encode_string(buf, &filter.filter)?;
114            if self.protocol_version == 5 {
115                buf.put_u8(filter.options.encode());
116            } else {
117                buf.put_u8(filter.options.qos as u8);
118            }
119        }
120
121        Ok(())
122    }
123
124    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
125        Self::decode_body_with_version(buf, fixed_header, 5)
126    }
127}
128
129impl SubscribePacket {
130    /// # Errors
131    /// Returns an error if decoding fails.
132    pub fn decode_body_with_version<B: Buf>(
133        buf: &mut B,
134        fixed_header: &FixedHeader,
135        protocol_version: u8,
136    ) -> Result<Self> {
137        ProtocolVersion::try_from(protocol_version)
138            .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
139
140        if fixed_header.flags != 0x02 {
141            return Err(MqttError::MalformedPacket(format!(
142                "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
143                fixed_header.flags
144            )));
145        }
146
147        if buf.remaining() < 2 {
148            return Err(MqttError::MalformedPacket(
149                "SUBSCRIBE missing packet identifier".to_string(),
150            ));
151        }
152        let packet_id = buf.get_u16();
153
154        let properties = if protocol_version == 5 {
155            Properties::decode(buf)?
156        } else {
157            Properties::default()
158        };
159
160        let mut filters = Vec::new();
161
162        if !buf.has_remaining() {
163            return Err(MqttError::MalformedPacket(
164                "SUBSCRIBE packet must contain at least one topic filter".to_string(),
165            ));
166        }
167
168        while buf.has_remaining() {
169            let filter_str = decode_string(buf)?;
170
171            if !buf.has_remaining() {
172                return Err(MqttError::MalformedPacket(
173                    "Missing subscription options for topic filter".to_string(),
174                ));
175            }
176
177            let options_byte = buf.get_u8();
178            let options = if protocol_version == 5 {
179                SubscriptionOptions::decode(options_byte)?
180            } else {
181                SubscriptionOptions {
182                    qos: QoS::from(options_byte & 0x03),
183                    ..Default::default()
184                }
185            };
186
187            filters.push(TopicFilter {
188                filter: filter_str,
189                options,
190            });
191        }
192
193        Ok(Self {
194            packet_id,
195            filters,
196            properties,
197            protocol_version,
198        })
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bytes::BytesMut;

    /// Builder calls append filters in order, each with its requested QoS.
    #[test]
    fn test_subscribe_basic() {
        let pkt = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(pkt.packet_id, 123);

        let expected = [
            ("temperature/+", QoS::AtLeastOnce),
            ("humidity/#", QoS::ExactlyOnce),
        ];
        assert_eq!(pkt.filters.len(), expected.len());
        for (entry, (name, qos)) in pkt.filters.iter().zip(expected.iter()) {
            assert_eq!(entry.filter, *name);
            assert_eq!(entry.options.qos, *qos);
        }
    }

    /// Explicit subscription options survive `add_filter_with_options`.
    #[test]
    fn test_subscribe_with_options() {
        let opts = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };
        let entry = TopicFilter::with_options("test/topic", opts);
        let pkt = SubscribePacket::new(456).add_filter_with_options(entry);

        let first = &pkt.filters[0].options;
        assert!(first.no_local);
        assert_eq!(first.retain_handling, RetainHandling::DoNotSend);
    }

    /// A v5 packet survives an encode/decode round trip, including properties.
    #[test]
    fn test_subscribe_encode_decode() {
        let original = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut wire = BytesMut::new();
        original.encode(&mut wire).unwrap();

        let header = FixedHeader::decode(&mut wire).unwrap();
        assert_eq!(header.packet_type, PacketType::Subscribe);
        assert_eq!(header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut wire, &header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);

        let (first, second) = (&decoded.filters[0], &decoded.filters[1]);
        assert_eq!(first.filter, "sensor/temp");
        assert_eq!(first.options.qos, QoS::AtMostOnce);
        assert_eq!(second.filter, "sensor/humidity");
        assert_eq!(second.options.qos, QoS::AtLeastOnce);

        assert!(decoded
            .properties
            .get(PropertyId::SubscriptionIdentifier)
            .is_some());
    }

    /// Fixed-header flags other than 0x02 are rejected.
    #[test]
    fn test_subscribe_invalid_flags() {
        let mut body = BytesMut::new();
        body.put_u16(123);

        let bad_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        assert!(SubscribePacket::decode_body(&mut body, &bad_header).is_err());
    }

    /// Encoding a packet without any topic filter fails.
    #[test]
    fn test_subscribe_empty_filters() {
        let mut sink = BytesMut::new();
        assert!(SubscribePacket::new(123).encode(&mut sink).is_err());
    }
}