1use alloc::vec::Vec;
2
3use bytes::Bytes;
4
5use crate::{
6 from_read_exact_error, read_string, read_u16, write_string, write_u16, AsyncRead, Encodable,
7 Error, Pid, QoS, QosPid, SyncWrite, TopicName,
8};
9
10use super::Header;
11
/// A PUBLISH packet: a topic, an optional packet identifier, flags, and an
/// opaque application payload.
///
/// Only the variable header and payload are represented here. The `dup` and
/// `retain` flags and the QoS level are taken from the separately decoded
/// fixed [`Header`] when decoding.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Publish {
    /// Duplicate-delivery flag, copied from `Header::dup` by `decode_async`.
    pub dup: bool,
    /// Retain flag, copied from `Header::retain` by `decode_async`.
    pub retain: bool,
    /// QoS level together with the packet identifier. The identifier exists
    /// only for `Level1`/`Level2`; it is decoded and encoded as a `u16`
    /// right after the topic name.
    pub qos_pid: QosPid,
    /// Topic the message is published to; validated via `TopicName::try_from`
    /// when decoding.
    pub topic_name: TopicName,
    /// Raw application message: every byte left in the packet after the
    /// topic name and (optional) packet identifier.
    pub payload: Bytes,
}
21
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Publish {
    /// Builds a `Publish` from fuzzer-supplied bytes.
    ///
    /// Fields are drawn from `u` in a fixed sequence: dup, QoS/pid, retain,
    /// topic, payload. Reordering the draws would change which packet a given
    /// input byte string maps to, so the sequence is deliberate.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let dup: bool = u.arbitrary()?;
        let qos_pid: QosPid = u.arbitrary()?;
        let retain: bool = u.arbitrary()?;
        let topic_name: TopicName = u.arbitrary()?;
        // Fully-qualified so the call does not depend on the trait being imported.
        let raw_payload: Vec<u8> = <Vec<u8> as arbitrary::Arbitrary<'a>>::arbitrary(u)?;

        Ok(Publish {
            dup,
            retain,
            qos_pid,
            topic_name,
            payload: Bytes::from(raw_payload),
        })
    }
}
34
35impl Publish {
36 pub fn new(qos_pid: QosPid, topic_name: TopicName, payload: Bytes) -> Self {
37 Publish {
38 dup: false,
39 retain: false,
40 qos_pid,
41 topic_name,
42 payload,
43 }
44 }
45
46 pub async fn decode_async<T: AsyncRead + Unpin>(
47 reader: &mut T,
48 header: Header,
49 ) -> Result<Self, Error> {
50 let mut remaining_len = header.remaining_len as usize;
51 let topic_name = read_string(reader).await?;
52 remaining_len = remaining_len
53 .checked_sub(2 + topic_name.len())
54 .ok_or(Error::InvalidRemainingLength)?;
55 let qos_pid = match header.qos {
56 QoS::Level0 => QosPid::Level0,
57 QoS::Level1 => {
58 remaining_len = remaining_len
59 .checked_sub(2)
60 .ok_or(Error::InvalidRemainingLength)?;
61 QosPid::Level1(Pid::try_from(read_u16(reader).await?)?)
62 }
63 QoS::Level2 => {
64 remaining_len = remaining_len
65 .checked_sub(2)
66 .ok_or(Error::InvalidRemainingLength)?;
67 QosPid::Level2(Pid::try_from(read_u16(reader).await?)?)
68 }
69 };
70 let payload = if remaining_len > 0 {
71 let mut data = alloc::vec![0u8; remaining_len];
72 reader
73 .read_exact(&mut data)
74 .await
75 .map_err(from_read_exact_error)?;
76 data
77 } else {
78 Vec::new()
79 };
80 Ok(Publish {
81 dup: header.dup,
82 qos_pid,
83 retain: header.retain,
84 topic_name: TopicName::try_from(topic_name)?,
85 payload: Bytes::from(payload),
86 })
87 }
88}
89
90impl Encodable for Publish {
91 fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
92 write_string(writer, &self.topic_name)?;
93 match self.qos_pid {
94 QosPid::Level0 => {}
95 QosPid::Level1(pid) | QosPid::Level2(pid) => {
96 write_u16(writer, pid.value())?;
97 }
98 }
99 writer.write_all(self.payload.as_ref())?;
100 Ok(())
101 }
102
103 fn encode_len(&self) -> usize {
104 let mut length = 2 + self.topic_name.len();
105 match self.qos_pid {
106 QosPid::Level0 => {}
107 QosPid::Level1(_) | QosPid::Level2(_) => {
108 length += 2;
109 }
110 }
111 length += self.payload.len();
112 length
113 }
114}