mqtt_proto/v3/publish.rs

1use alloc::vec::Vec;
2
3use bytes::Bytes;
4
5use crate::{
6    from_read_exact_error, read_string, read_u16, write_string, write_u16, AsyncRead, Encodable,
7    Error, Pid, QoS, QosPid, SyncWrite, TopicName,
8};
9
10use super::Header;
11
12/// Publish packet body type.
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub struct Publish {
15    pub dup: bool,
16    pub retain: bool,
17    pub qos_pid: QosPid,
18    pub topic_name: TopicName,
19    pub payload: Bytes,
20}
21
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Publish {
    /// Builds a random `Publish` from fuzzer-provided bytes.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // The draw order (dup, qos_pid, retain, topic_name, payload) decides how
        // the input bytes map onto fields. Keep it stable so fuzz corpora stay
        // reproducible.
        let dup: bool = u.arbitrary()?;
        let qos_pid: QosPid = u.arbitrary()?;
        let retain: bool = u.arbitrary()?;
        let topic_name: TopicName = u.arbitrary()?;
        let raw_payload = u.arbitrary::<Vec<u8>>()?;
        Ok(Self {
            dup,
            retain,
            qos_pid,
            topic_name,
            payload: Bytes::from(raw_payload),
        })
    }
}
34
35impl Publish {
36    pub fn new(qos_pid: QosPid, topic_name: TopicName, payload: Bytes) -> Self {
37        Publish {
38            dup: false,
39            retain: false,
40            qos_pid,
41            topic_name,
42            payload,
43        }
44    }
45
46    pub async fn decode_async<T: AsyncRead + Unpin>(
47        reader: &mut T,
48        header: Header,
49    ) -> Result<Self, Error> {
50        let mut remaining_len = header.remaining_len as usize;
51        let topic_name = read_string(reader).await?;
52        remaining_len = remaining_len
53            .checked_sub(2 + topic_name.len())
54            .ok_or(Error::InvalidRemainingLength)?;
55        let qos_pid = match header.qos {
56            QoS::Level0 => QosPid::Level0,
57            QoS::Level1 => {
58                remaining_len = remaining_len
59                    .checked_sub(2)
60                    .ok_or(Error::InvalidRemainingLength)?;
61                QosPid::Level1(Pid::try_from(read_u16(reader).await?)?)
62            }
63            QoS::Level2 => {
64                remaining_len = remaining_len
65                    .checked_sub(2)
66                    .ok_or(Error::InvalidRemainingLength)?;
67                QosPid::Level2(Pid::try_from(read_u16(reader).await?)?)
68            }
69        };
70        let payload = if remaining_len > 0 {
71            let mut data = alloc::vec![0u8; remaining_len];
72            reader
73                .read_exact(&mut data)
74                .await
75                .map_err(from_read_exact_error)?;
76            data
77        } else {
78            Vec::new()
79        };
80        Ok(Publish {
81            dup: header.dup,
82            qos_pid,
83            retain: header.retain,
84            topic_name: TopicName::try_from(topic_name)?,
85            payload: Bytes::from(payload),
86        })
87    }
88}
89
90impl Encodable for Publish {
91    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
92        write_string(writer, &self.topic_name)?;
93        match self.qos_pid {
94            QosPid::Level0 => {}
95            QosPid::Level1(pid) | QosPid::Level2(pid) => {
96                write_u16(writer, pid.value())?;
97            }
98        }
99        writer.write_all(self.payload.as_ref())?;
100        Ok(())
101    }
102
103    fn encode_len(&self) -> usize {
104        let mut length = 2 + self.topic_name.len();
105        match self.qos_pid {
106            QosPid::Level0 => {}
107            QosPid::Level1(_) | QosPid::Level2(_) => {
108                length += 2;
109            }
110        }
111        length += self.payload.len();
112        length
113    }
114}