1use super::*;
2use bytes::{Buf, Bytes};
3
4#[derive(Clone, Debug, PartialEq, Eq, Default)]
6pub struct Publish {
7 pub dup: bool,
8 pub qos: QoS,
9 pub retain: bool,
10 pub topic: Bytes,
11 pub pkid: u16,
12 pub payload: Bytes,
13 pub properties: Option<PublishProperties>,
14}
15
16impl Publish {
17 pub fn new<T: Into<String>, P: Into<Bytes>>(
18 topic: T,
19 qos: QoS,
20 payload: P,
21 properties: Option<PublishProperties>,
22 ) -> Self {
23 let topic = Bytes::copy_from_slice(topic.into().as_bytes());
24 Self {
25 qos,
26 topic,
27 payload: payload.into(),
28 properties,
29 ..Default::default()
30 }
31 }
32
33 pub fn size(&self) -> usize {
34 let len = self.len();
35 let remaining_len_size = len_len(len);
36
37 1 + remaining_len_size + len
38 }
39
40 fn len(&self) -> usize {
41 let mut len = 2 + self.topic.len();
42 if self.qos != QoS::AtMostOnce && self.pkid != 0 {
43 len += 2;
44 }
45
46 if let Some(p) = &self.properties {
47 let properties_len = p.len();
48 let properties_len_len = len_len(properties_len);
49 len += properties_len_len + properties_len;
50 } else {
51 len += 1;
53 }
54
55 len += self.payload.len();
56 len
57 }
58
59 pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Publish, Error> {
60 let qos_num = (fixed_header.byte1 & 0b0110) >> 1;
61 let qos = qos(qos_num).ok_or(Error::InvalidQoS(qos_num))?;
62 let dup = (fixed_header.byte1 & 0b1000) != 0;
63 let retain = (fixed_header.byte1 & 0b0001) != 0;
64
65 let variable_header_index = fixed_header.fixed_header_len;
66 bytes.advance(variable_header_index);
67 let topic = read_mqtt_bytes(&mut bytes)?;
68
69 let pkid = match qos {
71 QoS::AtMostOnce => 0,
72 QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?,
73 };
74
75 if qos != QoS::AtMostOnce && pkid == 0 {
76 return Err(Error::PacketIdZero);
77 }
78
79 let properties = PublishProperties::read(&mut bytes)?;
80 let publish = Publish {
81 dup,
82 retain,
83 qos,
84 pkid,
85 topic,
86 payload: bytes,
87 properties,
88 };
89
90 Ok(publish)
91 }
92
93 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
94 let len = self.len();
95
96 let dup = self.dup as u8;
97 let qos = self.qos as u8;
98 let retain = self.retain as u8;
99 buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3);
100
101 let count = write_remaining_length(buffer, len)?;
102 write_mqtt_bytes(buffer, &self.topic);
103
104 if self.qos != QoS::AtMostOnce {
105 let pkid = self.pkid;
106 if pkid == 0 {
107 return Err(Error::PacketIdZero);
108 }
109
110 buffer.put_u16(pkid);
111 }
112
113 if let Some(p) = &self.properties {
114 p.write(buffer)?;
115 } else {
116 write_remaining_length(buffer, 0)?;
117 }
118
119 buffer.extend_from_slice(&self.payload);
120
121 Ok(1 + count + len)
122 }
123}
124
125#[derive(Debug, Clone, PartialEq, Eq, Default)]
126pub struct PublishProperties {
127 pub payload_format_indicator: Option<u8>,
128 pub message_expiry_interval: Option<u32>,
129 pub topic_alias: Option<u16>,
130 pub response_topic: Option<String>,
131 pub correlation_data: Option<Bytes>,
132 pub user_properties: Vec<(String, String)>,
133 pub subscription_identifiers: Vec<usize>,
134 pub content_type: Option<String>,
135}
136
137impl PublishProperties {
138 fn len(&self) -> usize {
139 let mut len = 0;
140
141 if self.payload_format_indicator.is_some() {
142 len += 1 + 1;
143 }
144
145 if self.message_expiry_interval.is_some() {
146 len += 1 + 4;
147 }
148
149 if self.topic_alias.is_some() {
150 len += 1 + 2;
151 }
152
153 if let Some(topic) = &self.response_topic {
154 len += 1 + 2 + topic.len()
155 }
156
157 if let Some(data) = &self.correlation_data {
158 len += 1 + 2 + data.len()
159 }
160
161 for (key, value) in self.user_properties.iter() {
162 len += 1 + 2 + key.len() + 2 + value.len();
163 }
164
165 for id in self.subscription_identifiers.iter() {
166 len += 1 + len_len(*id);
167 }
168
169 if let Some(typ) = &self.content_type {
170 len += 1 + 2 + typ.len()
171 }
172
173 len
174 }
175
176 pub fn read(bytes: &mut Bytes) -> Result<Option<PublishProperties>, Error> {
177 let mut payload_format_indicator = None;
178 let mut message_expiry_interval = None;
179 let mut topic_alias = None;
180 let mut response_topic = None;
181 let mut correlation_data = None;
182 let mut user_properties = Vec::new();
183 let mut subscription_identifiers = Vec::new();
184 let mut content_type = None;
185
186 let (properties_len_len, properties_len) = length(bytes.iter())?;
187 bytes.advance(properties_len_len);
188 if properties_len == 0 {
189 return Ok(None);
190 }
191
192 let mut cursor = 0;
193 while cursor < properties_len {
195 let prop = read_u8(bytes)?;
196 cursor += 1;
197
198 match property(prop)? {
199 PropertyType::PayloadFormatIndicator => {
200 payload_format_indicator = Some(read_u8(bytes)?);
201 cursor += 1;
202 }
203 PropertyType::MessageExpiryInterval => {
204 message_expiry_interval = Some(read_u32(bytes)?);
205 cursor += 4;
206 }
207 PropertyType::TopicAlias => {
208 topic_alias = Some(read_u16(bytes)?);
209 cursor += 2;
210 }
211 PropertyType::ResponseTopic => {
212 let topic = read_mqtt_string(bytes)?;
213 cursor += 2 + topic.len();
214 response_topic = Some(topic);
215 }
216 PropertyType::CorrelationData => {
217 let data = read_mqtt_bytes(bytes)?;
218 cursor += 2 + data.len();
219 correlation_data = Some(data);
220 }
221 PropertyType::UserProperty => {
222 let key = read_mqtt_string(bytes)?;
223 let value = read_mqtt_string(bytes)?;
224 cursor += 2 + key.len() + 2 + value.len();
225 user_properties.push((key, value));
226 }
227 PropertyType::SubscriptionIdentifier => {
228 let (id_len, id) = length(bytes.iter())?;
229 cursor += 1 + id_len;
230 bytes.advance(id_len);
231 subscription_identifiers.push(id);
232 }
233 PropertyType::ContentType => {
234 let typ = read_mqtt_string(bytes)?;
235 cursor += 2 + typ.len();
236 content_type = Some(typ);
237 }
238 _ => return Err(Error::InvalidPropertyType(prop)),
239 }
240 }
241
242 Ok(Some(PublishProperties {
243 payload_format_indicator,
244 message_expiry_interval,
245 topic_alias,
246 response_topic,
247 correlation_data,
248 user_properties,
249 subscription_identifiers,
250 content_type,
251 }))
252 }
253
254 pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
255 let len = self.len();
256 write_remaining_length(buffer, len)?;
257
258 if let Some(payload_format_indicator) = self.payload_format_indicator {
259 buffer.put_u8(PropertyType::PayloadFormatIndicator as u8);
260 buffer.put_u8(payload_format_indicator);
261 }
262
263 if let Some(message_expiry_interval) = self.message_expiry_interval {
264 buffer.put_u8(PropertyType::MessageExpiryInterval as u8);
265 buffer.put_u32(message_expiry_interval);
266 }
267
268 if let Some(topic_alias) = self.topic_alias {
269 buffer.put_u8(PropertyType::TopicAlias as u8);
270 buffer.put_u16(topic_alias);
271 }
272
273 if let Some(topic) = &self.response_topic {
274 buffer.put_u8(PropertyType::ResponseTopic as u8);
275 write_mqtt_string(buffer, topic);
276 }
277
278 if let Some(data) = &self.correlation_data {
279 buffer.put_u8(PropertyType::CorrelationData as u8);
280 write_mqtt_bytes(buffer, data);
281 }
282
283 for (key, value) in self.user_properties.iter() {
284 buffer.put_u8(PropertyType::UserProperty as u8);
285 write_mqtt_string(buffer, key);
286 write_mqtt_string(buffer, value);
287 }
288
289 for id in self.subscription_identifiers.iter() {
290 buffer.put_u8(PropertyType::SubscriptionIdentifier as u8);
291 write_remaining_length(buffer, *id)?;
292 }
293
294 if let Some(typ) = &self.content_type {
295 buffer.put_u8(PropertyType::ContentType as u8);
296 write_mqtt_string(buffer, typ);
297 }
298
299 Ok(())
300 }
301}
302
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// `size()`, the count returned by `write()` and the number of bytes that
    /// actually land in the buffer must all agree.
    #[test]
    fn length_calculation() {
        // Only a user property and a subscription identifier are set.
        let properties = PublishProperties {
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
            subscription_identifiers: vec![1],
            ..Default::default()
        };

        let packet = Publish::new(
            "hello/world",
            QoS::AtMostOnce,
            vec![1; 10],
            Some(properties),
        );

        let predicted = packet.size();

        let mut encoded = BytesMut::new();
        let reported = packet.write(&mut encoded).unwrap();
        let actual = encoded.len();

        assert_eq!(reported, actual);
        assert_eq!(predicted, actual);
    }
}