1use std::convert::TryFrom;
2use std::io::{Cursor, Read};
3
4use bytes::{buf::Buf, Bytes};
5use bytestring::ByteString;
6
7use crate::error::ParseError;
8use crate::packet::*;
9use crate::proto::*;
10
11use super::{ConnectAckFlags, ConnectFlags, FixedHeader, WILL_QOS_SHIFT};
12
13pub(crate) fn read_packet(
14 src: &mut Cursor<Bytes>,
15 header: FixedHeader,
16) -> Result<Packet, ParseError> {
17 match header.packet_type {
18 CONNECT => decode_connect_packet(src),
19 CONNACK => decode_connect_ack_packet(src),
20 PUBLISH => decode_publish_packet(src, header),
21 PUBACK => Ok(Packet::PublishAck {
22 packet_id: read_u16(src)?,
23 }),
24 PUBREC => Ok(Packet::PublishReceived {
25 packet_id: read_u16(src)?,
26 }),
27 PUBREL => Ok(Packet::PublishRelease {
28 packet_id: read_u16(src)?,
29 }),
30 PUBCOMP => Ok(Packet::PublishComplete {
31 packet_id: read_u16(src)?,
32 }),
33 SUBSCRIBE => decode_subscribe_packet(src),
34 SUBACK => decode_subscribe_ack_packet(src),
35 UNSUBSCRIBE => decode_unsubscribe_packet(src),
36 UNSUBACK => Ok(Packet::UnsubscribeAck {
37 packet_id: read_u16(src)?,
38 }),
39 PINGREQ => Ok(Packet::PingRequest),
40 PINGRESP => Ok(Packet::PingResponse),
41 DISCONNECT => Ok(Packet::Disconnect),
42 _ => Err(ParseError::UnsupportedPacketType),
43 }
44}
45
/// Evaluates to `true` when every bit of `$flag.bits()` is set in `$flags`.
///
/// `$flag` must be an expression with a `bits()` method whose result can be
/// combined with `$flags` using `&` and compared with `==`.
macro_rules! check_flag {
    ($flags:expr, $flag:expr) => {{
        let masked = $flags & $flag.bits();
        masked == $flag.bits()
    }};
}
51
/// Returns early with `Err($e)` from the enclosing function unless `$cond`
/// holds.
///
/// The error expression is only evaluated when the condition is false. A
/// trailing comma after the error expression is accepted.
///
/// The former three-argument arm expanded to `Err(a, b, ...)`, which cannot
/// type-check for a one-field `Err`, so no working call could have used it;
/// it has been removed rather than kept as a trap.
macro_rules! ensure {
    ($cond:expr, $e:expr $(,)?) => {
        if !($cond) {
            return Err($e);
        }
    };
}
64
65pub fn decode_variable_length(src: &[u8]) -> Result<Option<(usize, usize)>, ParseError> {
66 if let Some((len, consumed, more)) = src
67 .iter()
68 .enumerate()
69 .scan((0, true), |state, (idx, x)| {
70 if !state.1 || idx > 3 {
71 return None;
72 }
73 state.0 += ((x & 0x7F) as usize) << (idx * 7);
74 state.1 = x & 0x80 != 0;
75 Some((state.0, idx + 1, state.1))
76 })
77 .last()
78 {
79 ensure!(!more || consumed < 4, ParseError::InvalidLength);
80 return Ok(Some((len, consumed)));
81 }
82
83 Ok(None)
84}
85
86fn decode_connect_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
87 ensure!(src.remaining() >= 10, ParseError::InvalidLength);
88 let len = src.get_u16();
89 if len >= 4 {
90 let mut ver = [0u8; 4];
91 src.read_exact(&mut ver).unwrap();
92 if &ver[..] != b"MQTT" {
93 return Err(ParseError::InvalidProtocol);
94 }
95 } else {
96 return Err(ParseError::InvalidProtocol);
97 }
98
99 let level = src.get_u8();
100 ensure!(
101 level == DEFAULT_MQTT_LEVEL,
102 ParseError::UnsupportedProtocolLevel
103 );
104
105 let flags = src.get_u8();
106 ensure!((flags & 0x01) == 0, ParseError::ConnectReservedFlagSet);
107
108 let keep_alive = src.get_u16();
109 let client_id = decode_utf8_str(src)?;
110
111 ensure!(
112 !client_id.is_empty() || check_flag!(flags, ConnectFlags::CLEAN_SESSION),
113 ParseError::InvalidClientId
114 );
115
116 let topic = if check_flag!(flags, ConnectFlags::WILL) {
117 Some(decode_utf8_str(src)?)
118 } else {
119 None
120 };
121 let message = if check_flag!(flags, ConnectFlags::WILL) {
122 Some(decode_length_bytes(src)?)
123 } else {
124 None
125 };
126 let username = if check_flag!(flags, ConnectFlags::USERNAME) {
127 Some(decode_utf8_str(src)?)
128 } else {
129 None
130 };
131 let password = if check_flag!(flags, ConnectFlags::PASSWORD) {
132 Some(decode_length_bytes(src)?)
133 } else {
134 None
135 };
136 let last_will = if topic.is_some() {
137 Some(LastWill {
138 qos: QoS::from((flags & ConnectFlags::WILL_QOS.bits()) >> WILL_QOS_SHIFT),
139 retain: check_flag!(flags, ConnectFlags::WILL_RETAIN),
140 topic: topic.unwrap(),
141 message: message.unwrap(),
142 })
143 } else {
144 None
145 };
146
147 Ok(Packet::Connect(Connect {
148 protocol: Protocol::MQTT(level),
149 clean_session: check_flag!(flags, ConnectFlags::CLEAN_SESSION),
150 keep_alive,
151 client_id,
152 last_will,
153 username,
154 password,
155 }))
156}
157
158fn decode_connect_ack_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
159 ensure!(src.remaining() >= 2, ParseError::InvalidLength);
160 let flags = src.get_u8();
161 ensure!(
162 (flags & 0b1111_1110) == 0,
163 ParseError::ConnAckReservedFlagSet
164 );
165
166 let return_code = src.get_u8();
167 Ok(Packet::ConnectAck {
168 session_present: check_flag!(flags, ConnectAckFlags::SESSION_PRESENT),
169 return_code: ConnectCode::from(return_code),
170 })
171}
172
173fn decode_publish_packet(
174 src: &mut Cursor<Bytes>,
175 header: FixedHeader,
176) -> Result<Packet, ParseError> {
177 let topic = decode_utf8_str(src)?;
178 let qos = QoS::from((header.packet_flags & 0b0110) >> 1);
179 let packet_id = if qos == QoS::AtMostOnce {
180 None
181 } else {
182 Some(read_u16(src)?)
183 };
184
185 let len = src.remaining();
186 let payload = take(src, len);
187
188 Ok(Packet::Publish(Publish {
189 dup: (header.packet_flags & 0b1000) == 0b1000,
190 qos,
191 retain: (header.packet_flags & 0b0001) == 0b0001,
192 topic,
193 packet_id,
194 payload,
195 }))
196}
197
198fn decode_subscribe_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
199 let packet_id = read_u16(src)?;
200 let mut topic_filters = Vec::new();
201 while src.remaining() > 0 {
202 let topic = decode_utf8_str(src)?;
203 ensure!(src.remaining() >= 1, ParseError::InvalidLength);
204 let qos = QoS::from(src.get_u8() & 0x03);
205 topic_filters.push((topic, qos));
206 }
207
208 Ok(Packet::Subscribe {
209 packet_id,
210 topic_filters,
211 })
212}
213
214fn decode_subscribe_ack_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
215 let packet_id = read_u16(src)?;
216 let status = src
217 .bytes()
218 .map(|code| {
220 let code = code.unwrap();
221 if code == 0x80 {
222 SubscribeReturnCode::Failure
223 } else {
224 SubscribeReturnCode::Success(QoS::from(code & 0x03))
225 }
226 })
227 .collect();
228 Ok(Packet::SubscribeAck { packet_id, status })
229}
230
231fn decode_unsubscribe_packet(src: &mut Cursor<Bytes>) -> Result<Packet, ParseError> {
232 let packet_id = read_u16(src)?;
233 let mut topic_filters = Vec::new();
234 while src.remaining() > 0 {
235 topic_filters.push(decode_utf8_str(src)?);
236 }
237 Ok(Packet::Unsubscribe {
238 packet_id,
239 topic_filters,
240 })
241}
242
243fn decode_length_bytes(src: &mut Cursor<Bytes>) -> Result<Bytes, ParseError> {
244 let len = read_u16(src)? as usize;
245 ensure!(src.remaining() >= len, ParseError::InvalidLength);
246 Ok(take(src, len))
247}
248
249fn decode_utf8_str(src: &mut Cursor<Bytes>) -> Result<ByteString, ParseError> {
250 Ok(ByteString::try_from(decode_length_bytes(src)?)?)
251}
252
253fn take(buf: &mut Cursor<Bytes>, n: usize) -> Bytes {
254 let pos = buf.position() as usize;
255 let ret = buf.get_ref().slice(pos..pos + n);
256 buf.set_position((pos + n) as u64);
257 ret
258}
259
260fn read_u16(src: &mut Cursor<Bytes>) -> Result<u16, ParseError> {
261 ensure!(src.remaining() >= 2, ParseError::InvalidLength);
262 Ok(src.get_u16())
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a static byte literal in a cursor positioned at its start.
    fn cursor(raw: &'static [u8]) -> Cursor<Bytes> {
        Cursor::new(Bytes::from_static(raw))
    }

    /// Builds a `ByteString` from a static literal; panics if the conversion fails.
    fn utf8(raw: &'static [u8]) -> ByteString {
        ByteString::try_from(Bytes::from_static(raw)).unwrap()
    }

    /// Splits a full wire packet into fixed header and body, decodes the body
    /// with `read_packet`, and compares the result with the expected packet.
    macro_rules! assert_decode_packet {
        ($bytes:expr, $res:expr) => {{
            let first = $bytes[0];
            let (_len, consumed) = decode_variable_length(&$bytes[1..]).unwrap().unwrap();
            // One byte of type/flags plus the encoded remaining length.
            let body_start = consumed + 1;
            let hdr = FixedHeader {
                packet_type: first >> 4,
                packet_flags: first & 0xF,
                remaining_length: $bytes.len() - body_start,
            };
            let mut cur = Cursor::new(Bytes::from_static(&$bytes[body_start..]));
            assert_eq!(read_packet(&mut cur, hdr), Ok($res));
        }};
    }

    #[test]
    fn test_decode_variable_length() {
        fn expect_length(input: &[u8], expected: (usize, usize)) {
            assert_eq!(decode_variable_length(input), Ok(Some(expected)));
        }

        // Decoding stops at the first byte without the continuation bit.
        expect_length(b"\x7f\x7f", (127, 1));

        assert_eq!(
            decode_variable_length(b"\xff\xff\xff\xff\xff\xff"),
            Err(ParseError::InvalidLength)
        );

        // Smallest and largest value for each encoded size.
        expect_length(b"\x00", (0, 1));
        expect_length(b"\x7f", (127, 1));
        expect_length(b"\x80\x01", (128, 2));
        expect_length(b"\xff\x7f", (16383, 2));
        expect_length(b"\x80\x80\x01", (16384, 3));
        expect_length(b"\xff\xff\x7f", (2097151, 3));
        expect_length(b"\x80\x80\x80\x01", (2097152, 4));
        expect_length(b"\xff\xff\xff\x7f", (268435455, 4));
    }

    #[test]
    fn test_decode_connect_packets() {
        // Flags 0xC0: a username and a password are expected, no will.
        let with_credentials = decode_connect_packet(&mut cursor(
            b"\x00\x04MQTT\x04\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass",
        ));
        assert_eq!(
            with_credentials,
            Ok(Packet::Connect(Connect {
                protocol: Protocol::MQTT(4),
                clean_session: false,
                keep_alive: 60,
                client_id: utf8(b"12345"),
                last_will: None,
                username: Some(utf8(b"user")),
                password: Some(Bytes::from_static(b"pass")),
            }))
        );

        // Flags 0x14: a will with QoS `ExactlyOnce` is expected, no credentials.
        let with_will = decode_connect_packet(&mut cursor(
            b"\x00\x04MQTT\x04\x14\x00\x3C\x00\x0512345\x00\x05topic\x00\x07message",
        ));
        assert_eq!(
            with_will,
            Ok(Packet::Connect(Connect {
                protocol: Protocol::MQTT(4),
                clean_session: false,
                keep_alive: 60,
                client_id: utf8(b"12345"),
                last_will: Some(LastWill {
                    qos: QoS::ExactlyOnce,
                    retain: false,
                    topic: utf8(b"topic"),
                    message: Bytes::from_static(b"message"),
                }),
                username: None,
                password: None,
            }))
        );

        // Protocol name shorter than four bytes.
        assert_eq!(
            decode_connect_packet(&mut cursor(b"\x00\x02MQ00000000000000000000")),
            Err(ParseError::InvalidProtocol)
        );
        // Declared name length of 16 with a name that is not "MQTT".
        assert_eq!(
            decode_connect_packet(&mut cursor(b"\x00\x10MQ00000000000000000000")),
            Err(ParseError::InvalidProtocol)
        );
        // Four-byte name that is not "MQTT".
        assert_eq!(
            decode_connect_packet(&mut cursor(b"\x00\x04MQAA00000000000000000000")),
            Err(ParseError::InvalidProtocol)
        );
        // Protocol level 3.
        assert_eq!(
            decode_connect_packet(&mut cursor(b"\x00\x04MQTT\x0300000000000000000000")),
            Err(ParseError::UnsupportedProtocolLevel)
        );
        // Flags byte with bit 0 set.
        assert_eq!(
            decode_connect_packet(&mut cursor(b"\x00\x04MQTT\x04\xff00000000000000000000")),
            Err(ParseError::ConnectReservedFlagSet)
        );

        assert_eq!(
            decode_connect_ack_packet(&mut cursor(b"\x01\x04")),
            Ok(Packet::ConnectAck {
                session_present: true,
                return_code: ConnectCode::BadUserNameOrPassword,
            })
        );

        // Flags byte with a bit other than bit 0 set.
        assert_eq!(
            decode_connect_ack_packet(&mut cursor(b"\x03\x04")),
            Err(ParseError::ConnAckReservedFlagSet)
        );

        assert_decode_packet!(
            b"\x20\x02\x01\x04",
            Packet::ConnectAck {
                session_present: true,
                return_code: ConnectCode::BadUserNameOrPassword,
            }
        );

        assert_decode_packet!(b"\xe0\x00", Packet::Disconnect);
    }

    #[test]
    fn test_decode_publish_packets() {
        // Header flags 0xD: dup, QoS bits 0b10, retain.
        assert_decode_packet!(
            b"\x3d\x0D\x00\x05topic\x43\x21data",
            Packet::Publish(Publish {
                dup: true,
                retain: true,
                qos: QoS::ExactlyOnce,
                topic: utf8(b"topic"),
                packet_id: Some(0x4321),
                payload: Bytes::from_static(b"data"),
            })
        );
        // Header flags 0x0: no packet identifier precedes the payload.
        assert_decode_packet!(
            b"\x30\x0b\x00\x05topicdata",
            Packet::Publish(Publish {
                dup: false,
                retain: false,
                qos: QoS::AtMostOnce,
                topic: utf8(b"topic"),
                packet_id: None,
                payload: Bytes::from_static(b"data"),
            })
        );

        assert_decode_packet!(
            b"\x40\x02\x43\x21",
            Packet::PublishAck { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x50\x02\x43\x21",
            Packet::PublishReceived { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x60\x02\x43\x21",
            Packet::PublishRelease { packet_id: 0x4321 }
        );
        assert_decode_packet!(
            b"\x70\x02\x43\x21",
            Packet::PublishComplete { packet_id: 0x4321 }
        );
    }

    #[test]
    fn test_decode_subscribe_packets() {
        let subscribe = Packet::Subscribe {
            packet_id: 0x1234,
            topic_filters: vec![
                (utf8(b"test"), QoS::AtLeastOnce),
                (utf8(b"filter"), QoS::ExactlyOnce),
            ],
        };

        assert_eq!(
            decode_subscribe_packet(&mut cursor(
                b"\x12\x34\x00\x04test\x01\x00\x06filter\x02"
            )),
            Ok(subscribe.clone())
        );
        assert_decode_packet!(
            b"\x82\x12\x12\x34\x00\x04test\x01\x00\x06filter\x02",
            subscribe
        );

        let subscribe_ack = Packet::SubscribeAck {
            packet_id: 0x1234,
            status: vec![
                SubscribeReturnCode::Success(QoS::AtLeastOnce),
                SubscribeReturnCode::Failure,
                SubscribeReturnCode::Success(QoS::ExactlyOnce),
            ],
        };

        assert_eq!(
            decode_subscribe_ack_packet(&mut cursor(b"\x12\x34\x01\x80\x02")),
            Ok(subscribe_ack.clone())
        );
        assert_decode_packet!(b"\x90\x05\x12\x34\x01\x80\x02", subscribe_ack);

        let unsubscribe = Packet::Unsubscribe {
            packet_id: 0x1234,
            topic_filters: vec![utf8(b"test"), utf8(b"filter")],
        };

        assert_eq!(
            decode_unsubscribe_packet(&mut cursor(b"\x12\x34\x00\x04test\x00\x06filter")),
            Ok(unsubscribe.clone())
        );
        assert_decode_packet!(
            b"\xa2\x10\x12\x34\x00\x04test\x00\x06filter",
            unsubscribe
        );

        assert_decode_packet!(
            b"\xb0\x02\x43\x21",
            Packet::UnsubscribeAck { packet_id: 0x4321 }
        );
    }

    #[test]
    fn test_decode_ping_packets() {
        assert_decode_packet!(b"\xc0\x00", Packet::PingRequest);
        assert_decode_packet!(b"\xd0\x00", Packet::PingResponse);
    }
}