1use std::io;
2
3use bytes::Bytes;
4use tokio::io::{AsyncRead, AsyncReadExt};
5
6use super::Header;
7use crate::{
8 read_string, read_u16, write_bytes, write_u16, Encodable, Error, Pid, QoS, QosPid, TopicName,
9};
10
/// Contents of a PUBLISH packet: the topic, the application payload, and the
/// flags / packet identifier that accompany them.
///
/// When decoded via `Publish::decode_async`, `dup`, `retain` and the QoS level
/// come from the already-parsed fixed header; the topic, optional packet
/// identifier and payload come from the bytes that follow it.
/// NOTE(review): layout matches MQTT v3.x PUBLISH — confirm target protocol version.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Publish {
    /// DUP flag; copied from the fixed header when decoding.
    pub dup: bool,
    /// RETAIN flag; copied from the fixed header when decoding.
    pub retain: bool,
    /// QoS level, paired with the packet identifier for QoS 1 and 2
    /// (the identifier is only present on the wire for those levels).
    pub qos_pid: QosPid,
    /// Topic the message is published to. Encoded with `write_bytes`, which
    /// `encode_len` accounts for as 2 bytes plus the topic's byte length.
    pub topic_name: TopicName,
    /// Application payload: every byte of the packet after the variable header.
    pub payload: Bytes,
}
20
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Publish {
    /// Builds a `Publish` from fuzzer-provided bytes.
    ///
    /// Values are drawn from `u` sequentially in this order: dup, qos_pid,
    /// retain, topic_name, payload. Reordering the draws would change which
    /// `Publish` a given input maps to.
    ///
    /// The payload is drawn with the inherent `Unstructured::arbitrary`
    /// method, so `arbitrary::Arbitrary` does not need to be imported here
    /// (calling `Vec::<u8>::arbitrary(u)` requires the trait to be in scope).
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        Ok(Publish {
            dup: u.arbitrary()?,
            qos_pid: u.arbitrary()?,
            retain: u.arbitrary()?,
            topic_name: u.arbitrary()?,
            payload: Bytes::from(u.arbitrary::<Vec<u8>>()?),
        })
    }
}
33
34impl Publish {
35 pub fn new(qos_pid: QosPid, topic_name: TopicName, payload: Bytes) -> Self {
36 Publish {
37 dup: false,
38 retain: false,
39 qos_pid,
40 topic_name,
41 payload,
42 }
43 }
44
45 pub async fn decode_async<T: AsyncRead + Unpin>(
46 reader: &mut T,
47 header: Header,
48 ) -> Result<Self, Error> {
49 let mut remaining_len = header.remaining_len as usize;
50 let topic_name = read_string(reader).await?;
51 remaining_len = remaining_len
52 .checked_sub(2 + topic_name.len())
53 .ok_or(Error::InvalidRemainingLength)?;
54 let qos_pid = match header.qos {
55 QoS::Level0 => QosPid::Level0,
56 QoS::Level1 => {
57 remaining_len = remaining_len
58 .checked_sub(2)
59 .ok_or(Error::InvalidRemainingLength)?;
60 QosPid::Level1(Pid::try_from(read_u16(reader).await?)?)
61 }
62 QoS::Level2 => {
63 remaining_len = remaining_len
64 .checked_sub(2)
65 .ok_or(Error::InvalidRemainingLength)?;
66 QosPid::Level2(Pid::try_from(read_u16(reader).await?)?)
67 }
68 };
69 let payload = if remaining_len > 0 {
70 let mut data = vec![0u8; remaining_len];
71 reader.read_exact(&mut data).await?;
72 data
73 } else {
74 Vec::new()
75 };
76 Ok(Publish {
77 dup: header.dup,
78 qos_pid,
79 retain: header.retain,
80 topic_name: TopicName::try_from(topic_name)?,
81 payload: Bytes::from(payload),
82 })
83 }
84}
85
86impl Encodable for Publish {
87 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
88 write_bytes(writer, self.topic_name.as_bytes())?;
89 match self.qos_pid {
90 QosPid::Level0 => {}
91 QosPid::Level1(pid) | QosPid::Level2(pid) => {
92 write_u16(writer, pid.value())?;
93 }
94 }
95 writer.write_all(self.payload.as_ref())?;
96 Ok(())
97 }
98
99 fn encode_len(&self) -> usize {
100 let mut length = 2 + self.topic_name.len();
101 match self.qos_pid {
102 QosPid::Level0 => {}
103 QosPid::Level1(_) | QosPid::Level2(_) => {
104 length += 2;
105 }
106 }
107 length += self.payload.len();
108 length
109 }
110}