mqtt_proto/v3/
publish.rs

1use std::io;
2
3use bytes::Bytes;
4use tokio::io::{AsyncRead, AsyncReadExt};
5
6use super::Header;
7use crate::{
8    read_string, read_u16, write_bytes, write_u16, Encodable, Error, Pid, QoS, QosPid, TopicName,
9};
10
11/// Publish packet body type.
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
13pub struct Publish {
14    pub dup: bool,
15    pub retain: bool,
16    pub qos_pid: QosPid,
17    pub topic_name: TopicName,
18    pub payload: Bytes,
19}
20
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Publish {
    /// Generates a random `Publish` from fuzzer-provided bytes.
    ///
    /// Fields are drawn from `u` in the order dup, qos_pid, retain, topic_name,
    /// payload. Keep that order stable so existing fuzz inputs keep decoding to
    /// the same values.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let dup: bool = u.arbitrary()?;
        let qos_pid: QosPid = u.arbitrary()?;
        let retain: bool = u.arbitrary()?;
        let topic_name: TopicName = u.arbitrary()?;
        let raw_payload: Vec<u8> = u.arbitrary()?;
        Ok(Publish {
            dup,
            retain,
            qos_pid,
            topic_name,
            payload: Bytes::from(raw_payload),
        })
    }
}
33
34impl Publish {
35    pub fn new(qos_pid: QosPid, topic_name: TopicName, payload: Bytes) -> Self {
36        Publish {
37            dup: false,
38            retain: false,
39            qos_pid,
40            topic_name,
41            payload,
42        }
43    }
44
45    pub async fn decode_async<T: AsyncRead + Unpin>(
46        reader: &mut T,
47        header: Header,
48    ) -> Result<Self, Error> {
49        let mut remaining_len = header.remaining_len as usize;
50        let topic_name = read_string(reader).await?;
51        remaining_len = remaining_len
52            .checked_sub(2 + topic_name.len())
53            .ok_or(Error::InvalidRemainingLength)?;
54        let qos_pid = match header.qos {
55            QoS::Level0 => QosPid::Level0,
56            QoS::Level1 => {
57                remaining_len = remaining_len
58                    .checked_sub(2)
59                    .ok_or(Error::InvalidRemainingLength)?;
60                QosPid::Level1(Pid::try_from(read_u16(reader).await?)?)
61            }
62            QoS::Level2 => {
63                remaining_len = remaining_len
64                    .checked_sub(2)
65                    .ok_or(Error::InvalidRemainingLength)?;
66                QosPid::Level2(Pid::try_from(read_u16(reader).await?)?)
67            }
68        };
69        let payload = if remaining_len > 0 {
70            let mut data = vec![0u8; remaining_len];
71            reader.read_exact(&mut data).await?;
72            data
73        } else {
74            Vec::new()
75        };
76        Ok(Publish {
77            dup: header.dup,
78            qos_pid,
79            retain: header.retain,
80            topic_name: TopicName::try_from(topic_name)?,
81            payload: Bytes::from(payload),
82        })
83    }
84}
85
86impl Encodable for Publish {
87    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
88        write_bytes(writer, self.topic_name.as_bytes())?;
89        match self.qos_pid {
90            QosPid::Level0 => {}
91            QosPid::Level1(pid) | QosPid::Level2(pid) => {
92                write_u16(writer, pid.value())?;
93            }
94        }
95        writer.write_all(self.payload.as_ref())?;
96        Ok(())
97    }
98
99    fn encode_len(&self) -> usize {
100        let mut length = 2 + self.topic_name.len();
101        match self.qos_pid {
102            QosPid::Level0 => {}
103            QosPid::Level1(_) | QosPid::Level2(_) => {
104                length += 2;
105            }
106        }
107        length += self.payload.len();
108        length
109    }
110}