1use crate::encoding::{decode_string, encode_string};
2use crate::error::{MqttError, Result};
3use crate::packet::{FixedHeader, MqttPacket, PacketType};
4use crate::prelude::{format, String, ToString, Vec};
5use crate::protocol::v5::properties::Properties;
6use crate::types::ProtocolVersion;
7use crate::QoS;
8use bytes::{Buf, BufMut};
9
10pub use super::subscribe_options::{RetainHandling, SubscriptionOptions, SubscriptionOptionsBits};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct TopicFilter {
14 pub filter: String,
15 pub options: SubscriptionOptions,
16}
17
18impl TopicFilter {
19 #[must_use]
20 pub fn new(filter: impl Into<String>, qos: QoS) -> Self {
21 Self {
22 filter: filter.into(),
23 options: SubscriptionOptions::new(qos),
24 }
25 }
26
27 #[must_use]
28 pub fn with_options(filter: impl Into<String>, options: SubscriptionOptions) -> Self {
29 Self {
30 filter: filter.into(),
31 options,
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
37pub struct SubscribePacket {
38 pub packet_id: u16,
39 pub filters: Vec<TopicFilter>,
40 pub properties: Properties,
41 pub protocol_version: u8,
42}
43
44impl SubscribePacket {
45 #[must_use]
46 pub fn new(packet_id: u16) -> Self {
47 Self {
48 packet_id,
49 filters: Vec::new(),
50 properties: Properties::default(),
51 protocol_version: 5,
52 }
53 }
54
55 #[must_use]
56 pub fn new_v311(packet_id: u16) -> Self {
57 Self {
58 packet_id,
59 filters: Vec::new(),
60 properties: Properties::default(),
61 protocol_version: 4,
62 }
63 }
64
65 #[must_use]
66 pub fn add_filter(mut self, filter: impl Into<String>, qos: QoS) -> Self {
67 self.filters.push(TopicFilter::new(filter, qos));
68 self
69 }
70
71 #[must_use]
72 pub fn add_filter_with_options(mut self, filter: TopicFilter) -> Self {
73 self.filters.push(filter);
74 self
75 }
76
77 #[must_use]
78 pub fn with_subscription_identifier(mut self, id: u32) -> Self {
79 self.properties.set_subscription_identifier(id);
80 self
81 }
82
83 #[must_use]
84 pub fn with_user_property(mut self, key: String, value: String) -> Self {
85 self.properties.add_user_property(key, value);
86 self
87 }
88}
89
90impl MqttPacket for SubscribePacket {
91 fn packet_type(&self) -> PacketType {
92 PacketType::Subscribe
93 }
94
95 fn flags(&self) -> u8 {
96 0x02
97 }
98
99 fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
100 buf.put_u16(self.packet_id);
101
102 if self.protocol_version == 5 {
103 self.properties.encode(buf)?;
104 }
105
106 if self.filters.is_empty() {
107 return Err(MqttError::MalformedPacket(
108 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
109 ));
110 }
111
112 for filter in &self.filters {
113 encode_string(buf, &filter.filter)?;
114 if self.protocol_version == 5 {
115 buf.put_u8(filter.options.encode());
116 } else {
117 buf.put_u8(filter.options.qos as u8);
118 }
119 }
120
121 Ok(())
122 }
123
124 fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
125 Self::decode_body_with_version(buf, fixed_header, 5)
126 }
127}
128
129impl SubscribePacket {
130 pub fn decode_body_with_version<B: Buf>(
133 buf: &mut B,
134 fixed_header: &FixedHeader,
135 protocol_version: u8,
136 ) -> Result<Self> {
137 ProtocolVersion::try_from(protocol_version)
138 .map_err(|()| MqttError::UnsupportedProtocolVersion)?;
139
140 if fixed_header.flags != 0x02 {
141 return Err(MqttError::MalformedPacket(format!(
142 "Invalid SUBSCRIBE flags: expected 0x02, got 0x{:02X}",
143 fixed_header.flags
144 )));
145 }
146
147 if buf.remaining() < 2 {
148 return Err(MqttError::MalformedPacket(
149 "SUBSCRIBE missing packet identifier".to_string(),
150 ));
151 }
152 let packet_id = buf.get_u16();
153
154 let properties = if protocol_version == 5 {
155 Properties::decode(buf)?
156 } else {
157 Properties::default()
158 };
159
160 let mut filters = Vec::new();
161
162 if !buf.has_remaining() {
163 return Err(MqttError::MalformedPacket(
164 "SUBSCRIBE packet must contain at least one topic filter".to_string(),
165 ));
166 }
167
168 while buf.has_remaining() {
169 let filter_str = decode_string(buf)?;
170
171 if !buf.has_remaining() {
172 return Err(MqttError::MalformedPacket(
173 "Missing subscription options for topic filter".to_string(),
174 ));
175 }
176
177 let options_byte = buf.get_u8();
178 let options = if protocol_version == 5 {
179 SubscriptionOptions::decode(options_byte)?
180 } else {
181 SubscriptionOptions {
182 qos: QoS::from(options_byte & 0x03),
183 ..Default::default()
184 }
185 };
186
187 filters.push(TopicFilter {
188 filter: filter_str,
189 options,
190 });
191 }
192
193 Ok(Self {
194 packet_id,
195 filters,
196 properties,
197 protocol_version,
198 })
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::v5::properties::PropertyId;
    use bytes::BytesMut;

    #[test]
    fn test_subscribe_basic() {
        let packet = SubscribePacket::new(123)
            .add_filter("temperature/+", QoS::AtLeastOnce)
            .add_filter("humidity/#", QoS::ExactlyOnce);

        assert_eq!(packet.packet_id, 123);
        assert_eq!(packet.filters.len(), 2);
        assert_eq!(packet.filters[0].filter, "temperature/+");
        assert_eq!(packet.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(packet.filters[1].filter, "humidity/#");
        assert_eq!(packet.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_with_options() {
        let options = SubscriptionOptions {
            qos: QoS::AtLeastOnce,
            no_local: true,
            retain_as_published: false,
            retain_handling: RetainHandling::DoNotSend,
        };

        let packet = SubscribePacket::new(456)
            .add_filter_with_options(TopicFilter::with_options("test/topic", options));

        assert!(packet.filters[0].options.no_local);
        assert_eq!(
            packet.filters[0].options.retain_handling,
            RetainHandling::DoNotSend
        );
    }

    #[test]
    fn test_subscribe_encode_decode() {
        let packet = SubscribePacket::new(789)
            .add_filter("sensor/temp", QoS::AtMostOnce)
            .add_filter("sensor/humidity", QoS::AtLeastOnce)
            .with_subscription_identifier(42);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded = SubscribePacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.packet_id, 789);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "sensor/temp");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtMostOnce);
        assert_eq!(decoded.filters[1].filter, "sensor/humidity");
        assert_eq!(decoded.filters[1].options.qos, QoS::AtLeastOnce);

        let sub_id = decoded.properties.get(PropertyId::SubscriptionIdentifier);
        assert!(sub_id.is_some());
    }

    /// The v3.1.1 path writes no properties and a bare QoS byte per filter;
    /// make sure it round-trips through the versioned decoder.
    #[test]
    fn test_subscribe_v311_encode_decode() {
        let packet = SubscribePacket::new_v311(321)
            .add_filter("a/b", QoS::AtLeastOnce)
            .add_filter("c/#", QoS::ExactlyOnce);

        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Subscribe);
        assert_eq!(fixed_header.flags, 0x02);

        let decoded =
            SubscribePacket::decode_body_with_version(&mut buf, &fixed_header, 4).unwrap();
        assert_eq!(decoded.packet_id, 321);
        assert_eq!(decoded.protocol_version, 4);
        assert_eq!(decoded.filters.len(), 2);
        assert_eq!(decoded.filters[0].filter, "a/b");
        assert_eq!(decoded.filters[0].options.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.filters[1].filter, "c/#");
        assert_eq!(decoded.filters[1].options.qos, QoS::ExactlyOnce);
    }

    #[test]
    fn test_subscribe_invalid_flags() {
        let mut buf = BytesMut::new();
        buf.put_u16(123);

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x00, 2);
        let result = SubscribePacket::decode_body(&mut buf, &fixed_header);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_missing_packet_id() {
        let mut buf = BytesMut::new();
        buf.put_u8(0x01);

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x02, 1);
        let result = SubscribePacket::decode_body_with_version(&mut buf, &fixed_header, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_missing_options_byte() {
        // v3.1.1 body: packet id, then a one-byte topic filter "a" with no
        // trailing Requested QoS byte.
        let mut buf = BytesMut::new();
        buf.put_u16(7);
        buf.put_u16(1);
        buf.put_u8(b'a');

        let fixed_header = FixedHeader::new(PacketType::Subscribe, 0x02, 5);
        let result = SubscribePacket::decode_body_with_version(&mut buf, &fixed_header, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_subscribe_empty_filters() {
        let packet = SubscribePacket::new(123);

        let mut buf = BytesMut::new();
        let result = packet.encode(&mut buf);
        assert!(result.is_err());
    }
}