1use alloc::string::String;
7use alloc::vec::Vec;
8use core::fmt;
9
10use crate::data_types::{
11 DataTypeError, decode_two_byte_int, decode_utf8_string, encode_two_byte_int, encode_utf8_string,
12};
13use crate::packet::{ControlPacketType, FixedHeader};
14use crate::vbi::{VbiError, decode_vbi, encode_vbi};
15
16#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum CodecError {
19 Vbi(VbiError),
21 DataType(DataTypeError),
23 HeaderTooShort,
25 WrongPacketType(u8),
28 MissingPacketIdentifier,
31 InvalidQoS(u8),
33 RemainingLengthMismatch,
35}
36
37impl From<VbiError> for CodecError {
38 fn from(e: VbiError) -> Self {
39 Self::Vbi(e)
40 }
41}
42
43impl From<DataTypeError> for CodecError {
44 fn from(e: DataTypeError) -> Self {
45 Self::DataType(e)
46 }
47}
48
49impl fmt::Display for CodecError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 Self::Vbi(e) => write!(f, "VBI: {e}"),
53 Self::DataType(e) => write!(f, "data type: {e}"),
54 Self::HeaderTooShort => f.write_str("packet header too short"),
55 Self::WrongPacketType(t) => write!(f, "wrong packet type {t}"),
56 Self::MissingPacketIdentifier => f.write_str("missing packet identifier"),
57 Self::InvalidQoS(q) => write!(f, "invalid QoS {q}"),
58 Self::RemainingLengthMismatch => f.write_str("remaining length exceeds bytes"),
59 }
60 }
61}
62
/// With the `std` feature enabled, `CodecError` can be used wherever a
/// `std::error::Error` is expected (for example boxed as
/// `Box<dyn std::error::Error>`). The impl is empty, so no `source()` is
/// reported for the wrapped `Vbi` / `DataType` errors.
#[cfg(feature = "std")]
impl std::error::Error for CodecError {}
65
/// In-memory form of an MQTT PUBLISH packet, as produced by `decode_publish`
/// and consumed by `encode_publish`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PublishPacket {
    /// DUP flag: bit 3 of the fixed-header flags.
    pub dup: bool,
    /// Quality of service, stored in fixed-header flag bits 2-1. Encoding
    /// rejects values above 2.
    pub qos: u8,
    /// RETAIN flag: bit 0 of the fixed-header flags.
    pub retain: bool,
    /// Topic name, written as a length-prefixed UTF-8 string.
    pub topic: String,
    /// Packet identifier. It is only on the wire when `qos > 0`: encoding
    /// requires it in that case and ignores it for QoS 0, and decoding yields
    /// `None` for QoS 0.
    pub packet_id: Option<u16>,
    /// Raw property bytes, excluding their VBI length prefix. They are not
    /// parsed or validated by this module.
    pub properties: Vec<u8>,
    /// Application payload: every byte of the body after the properties.
    pub payload: Vec<u8>,
}
86
87pub fn encode_publish(p: &PublishPacket) -> Result<Vec<u8>, CodecError> {
94 if p.qos > 2 {
95 return Err(CodecError::InvalidQoS(p.qos));
96 }
97 if p.qos > 0 && p.packet_id.is_none() {
98 return Err(CodecError::MissingPacketIdentifier);
99 }
100
101 let mut var_header = encode_utf8_string(&p.topic)?;
103 if p.qos > 0 {
104 let id = p.packet_id.ok_or(CodecError::MissingPacketIdentifier)?;
105 var_header.extend_from_slice(&encode_two_byte_int(id));
106 }
107 let prop_len_u32 =
112 u32::try_from(p.properties.len()).map_err(|_| CodecError::Vbi(VbiError::Malformed))?;
113 let prop_len_vbi = encode_vbi(prop_len_u32).ok_or(CodecError::Vbi(VbiError::Malformed))?;
114 var_header.extend_from_slice(&prop_len_vbi);
115 var_header.extend_from_slice(&p.properties);
116
117 let mut body = var_header;
119 body.extend_from_slice(&p.payload);
120
121 let mut flags = 0u8;
123 if p.dup {
124 flags |= 0b1000;
125 }
126 flags |= (p.qos & 0b11) << 1;
127 if p.retain {
128 flags |= 0b0001;
129 }
130 let byte0 = (ControlPacketType::Publish.to_bits() << 4) | (flags & 0x0F);
131 let mut out = Vec::with_capacity(1 + 4 + body.len());
132 out.push(byte0);
133 #[allow(clippy::cast_possible_truncation)]
134 let remaining_length =
135 u32::try_from(body.len()).map_err(|_| CodecError::Vbi(VbiError::Malformed))?;
136 let vbi_bytes = encode_vbi(remaining_length).ok_or(CodecError::Vbi(VbiError::Malformed))?;
137 out.extend_from_slice(&vbi_bytes);
138 out.extend_from_slice(&body);
139 Ok(out)
140}
141
142pub fn decode_publish(bytes: &[u8]) -> Result<(FixedHeader, PublishPacket), CodecError> {
147 if bytes.is_empty() {
148 return Err(CodecError::HeaderTooShort);
149 }
150 let byte0 = bytes[0];
151 let packet_type_bits = (byte0 >> 4) & 0x0F;
152 if packet_type_bits != ControlPacketType::Publish.to_bits() {
153 return Err(CodecError::WrongPacketType(packet_type_bits));
154 }
155 let flags = byte0 & 0x0F;
156 let qos = (flags >> 1) & 0b11;
157 if qos > 2 {
158 return Err(CodecError::InvalidQoS(qos));
159 }
160 let dup = flags & 0b1000 != 0;
161 let retain = flags & 0b0001 != 0;
162
163 let (remaining_length, vbi_used) = decode_vbi(&bytes[1..])?;
164 let header_total = 1 + vbi_used;
165 let body_end = header_total + remaining_length as usize;
166 if bytes.len() < body_end {
167 return Err(CodecError::RemainingLengthMismatch);
168 }
169 let body = &bytes[header_total..body_end];
170
171 let mut cursor = 0usize;
173 let (topic, used) = decode_utf8_string(&body[cursor..])?;
174 cursor += used;
175 let packet_id = if qos > 0 {
176 let (id, used) = decode_two_byte_int(&body[cursor..])?;
177 cursor += used;
178 Some(id)
179 } else {
180 None
181 };
182 let (prop_len, prop_vbi_used) = decode_vbi(&body[cursor..])?;
187 cursor += prop_vbi_used;
188 let prop_data_end = cursor + prop_len as usize;
189 if body.len() < prop_data_end {
190 return Err(CodecError::RemainingLengthMismatch);
191 }
192 let properties = if prop_len == 0 {
193 Vec::new()
194 } else {
195 body[cursor..prop_data_end].to_vec()
196 };
197 cursor = prop_data_end;
198
199 let payload = body[cursor..].to_vec();
201
202 let header = FixedHeader {
203 packet_type: ControlPacketType::Publish,
204 flags,
205 remaining_length,
206 };
207 Ok((
208 header,
209 PublishPacket {
210 dup,
211 qos,
212 retain,
213 topic,
214 packet_id,
215 properties,
216 payload,
217 },
218 ))
219}
220
#[cfg(test)]
// Unit tests may unwrap/expect/panic: a failure there is the test failing.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn publish_qos0_no_packet_id_round_trip() {
        let p = PublishPacket {
            dup: false,
            qos: 0,
            retain: false,
            topic: String::from("sensors/temp"),
            packet_id: None,
            properties: Vec::new(),
            payload: alloc::vec![0xDE, 0xAD],
        };
        let bytes = encode_publish(&p).expect("encode");
        let (hdr, parsed) = decode_publish(&bytes).expect("decode");
        assert_eq!(parsed, p);
        assert_eq!(hdr.packet_type, ControlPacketType::Publish);
        assert!(!hdr.dup_flag());
        assert_eq!(hdr.qos(), 0);
    }

    #[test]
    fn publish_qos1_includes_packet_id_round_trip() {
        let p = PublishPacket {
            dup: true,
            qos: 1,
            retain: true,
            topic: String::from("foo"),
            packet_id: Some(0x1234),
            properties: Vec::new(),
            payload: b"hello".to_vec(),
        };
        let bytes = encode_publish(&p).expect("encode");
        let (_, parsed) = decode_publish(&bytes).expect("decode");
        assert_eq!(parsed, p);
    }

    #[test]
    fn publish_qos2_round_trip() {
        let p = PublishPacket {
            dup: false,
            qos: 2,
            retain: false,
            topic: String::from("a/b/c"),
            packet_id: Some(42),
            properties: Vec::new(),
            payload: alloc::vec![1, 2, 3, 4, 5],
        };
        let bytes = encode_publish(&p).expect("encode");
        let (_, parsed) = decode_publish(&bytes).expect("decode");
        assert_eq!(parsed.packet_id, Some(42));
        assert_eq!(parsed.qos, 2);
        assert_eq!(parsed, p);
    }

    #[test]
    fn invalid_qos_3_rejected_on_encode() {
        let mut p = PublishPacket {
            dup: false,
            qos: 3,
            retain: false,
            topic: String::from("x"),
            packet_id: None,
            properties: Vec::new(),
            payload: Vec::new(),
        };
        assert_eq!(encode_publish(&p), Err(CodecError::InvalidQoS(3)));
        p.qos = 2;
        p.packet_id = Some(1);
        assert!(encode_publish(&p).is_ok());
    }

    #[test]
    fn invalid_qos_3_rejected_on_decode() {
        // 0x36: packet type 3 (PUBLISH), flags 0b0110 -> both QoS bits set.
        assert_eq!(
            decode_publish(&[0x36, 0x00]),
            Err(CodecError::InvalidQoS(3))
        );
    }

    #[test]
    fn missing_packet_id_at_qos1_rejected() {
        let p = PublishPacket {
            dup: false,
            qos: 1,
            retain: false,
            topic: String::from("x"),
            packet_id: None,
            properties: Vec::new(),
            payload: Vec::new(),
        };
        assert_eq!(encode_publish(&p), Err(CodecError::MissingPacketIdentifier));
    }

    #[test]
    fn wrong_packet_type_rejected_on_decode() {
        // 0x10 has packet type 1 in its high nibble, not PUBLISH.
        let bytes = [0x10u8, 0x02, 0, 0];
        match decode_publish(&bytes) {
            Err(CodecError::WrongPacketType(1)) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn fixed_header_first_byte_layout_for_publish() {
        let p = PublishPacket {
            dup: true,
            qos: 2,
            retain: true,
            topic: String::from("t"),
            packet_id: Some(1),
            properties: Vec::new(),
            payload: Vec::new(),
        };
        let bytes = encode_publish(&p).expect("encode");
        // 0x3D = type 3 << 4 | DUP (0x08) | QoS 2 << 1 (0x04) | RETAIN (0x01).
        assert_eq!(bytes[0], 0x3D);
    }

    #[test]
    fn empty_properties_round_trips_as_empty_vec() {
        let p = PublishPacket {
            dup: false,
            qos: 0,
            retain: false,
            topic: String::from("t"),
            packet_id: None,
            properties: Vec::new(),
            payload: alloc::vec![1],
        };
        let bytes = encode_publish(&p).expect("encode");
        let (_, parsed) = decode_publish(&bytes).expect("decode");
        assert!(parsed.properties.is_empty());
    }

    #[test]
    fn non_empty_properties_round_trip_preserves_bytes() {
        let raw_props_payload = alloc::vec![0x01u8, 0x01, 0x21, 0x00, 0x0A];
        let p = PublishPacket {
            dup: false,
            qos: 0,
            retain: false,
            topic: String::from("t"),
            packet_id: None,
            properties: raw_props_payload.clone(),
            payload: alloc::vec![],
        };
        let bytes = encode_publish(&p).expect("encode");
        let (_, parsed) = decode_publish(&bytes).expect("decode");
        assert_eq!(parsed.properties, raw_props_payload);
    }

    #[test]
    fn truncated_remaining_length_decode_fails() {
        // Remaining length claims 10 bytes but only 3 follow.
        let bytes = [0x30u8, 0x0A, 0, 1, b'x'];
        assert_eq!(
            decode_publish(&bytes),
            Err(CodecError::RemainingLengthMismatch)
        );
    }

    #[test]
    fn property_length_past_body_decode_fails() {
        // Body = topic "t" (3 bytes), property length 5, but only 1 byte left.
        let bytes = [0x30u8, 0x05, 0x00, 0x01, b't', 0x05, 0xAA];
        assert_eq!(
            decode_publish(&bytes),
            Err(CodecError::RemainingLengthMismatch)
        );
    }

    #[test]
    fn empty_input_decode_fails() {
        assert_eq!(decode_publish(&[]), Err(CodecError::HeaderTooShort));
    }

    #[test]
    fn large_payload_encodes_multibyte_remaining_length() {
        let p = PublishPacket {
            dup: false,
            qos: 0,
            retain: false,
            topic: String::from("t"),
            packet_id: None,
            properties: Vec::new(),
            payload: alloc::vec![0xAB; 200],
        };
        let bytes = encode_publish(&p).expect("encode");
        assert_eq!(bytes[0], 0x30);
        // Body = topic (2-byte length + "t") + 1-byte property length + 200
        // payload bytes = 204, which needs a two-byte VBI: 0xCC 0x01.
        assert_eq!(&bytes[1..3], &[0xCCu8, 0x01]);
        assert_eq!(bytes.len(), 1 + 2 + 204);
        let (hdr, parsed) = decode_publish(&bytes).expect("decode");
        assert_eq!(hdr.remaining_length, 204);
        assert_eq!(parsed.payload.len(), 200);
    }
}