mqtt/packet/
publish.rs

1//! PUBLISH
2
3use std::io::{self, Read, Write};
4
5use crate::control::{FixedHeader, PacketType};
6use crate::packet::{DecodablePacket, PacketError};
7use crate::qos::QualityOfService;
8use crate::topic_name::TopicName;
9use crate::{control::variable_header::PacketIdentifier, TopicNameRef};
10use crate::{Decodable, Encodable};
11
12use super::EncodablePacket;
13
/// A `PUBLISH` quality-of-service level paired with the packet identifier it needs.
///
/// QoS 0 carries no packet identifier; QoS 1 and QoS 2 each carry one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum QoSWithPacketIdentifier {
    /// QoS 0 (no packet identifier).
    Level0,
    /// QoS 1 with its packet identifier.
    Level1(u16),
    /// QoS 2 with its packet identifier.
    Level2(u16),
}
21
22impl QoSWithPacketIdentifier {
23    pub fn new(qos: QualityOfService, id: u16) -> QoSWithPacketIdentifier {
24        match (qos, id) {
25            (QualityOfService::Level0, _) => QoSWithPacketIdentifier::Level0,
26            (QualityOfService::Level1, id) => QoSWithPacketIdentifier::Level1(id),
27            (QualityOfService::Level2, id) => QoSWithPacketIdentifier::Level2(id),
28        }
29    }
30
31    pub fn split(self) -> (QualityOfService, Option<u16>) {
32        match self {
33            QoSWithPacketIdentifier::Level0 => (QualityOfService::Level0, None),
34            QoSWithPacketIdentifier::Level1(pkid) => (QualityOfService::Level1, Some(pkid)),
35            QoSWithPacketIdentifier::Level2(pkid) => (QualityOfService::Level2, Some(pkid)),
36        }
37    }
38}
39
40/// `PUBLISH` packet
41#[derive(Debug, Eq, PartialEq, Clone)]
42pub struct PublishPacket {
43    fixed_header: FixedHeader,
44    topic_name: TopicName,
45    packet_identifier: Option<PacketIdentifier>,
46    payload: Vec<u8>,
47}
48
49encodable_packet!(PublishPacket(topic_name, packet_identifier, payload));
50
51impl PublishPacket {
52    pub fn new<P: Into<Vec<u8>>>(topic_name: TopicName, qos: QoSWithPacketIdentifier, payload: P) -> PublishPacket {
53        let (qos, pkid) = qos.split();
54        let mut pk = PublishPacket {
55            fixed_header: FixedHeader::new(PacketType::publish(qos), 0),
56            topic_name,
57            packet_identifier: pkid.map(PacketIdentifier),
58            payload: payload.into(),
59        };
60        pk.fix_header_remaining_len();
61        pk
62    }
63
64    pub fn set_dup(&mut self, dup: bool) {
65        self.fixed_header
66            .packet_type
67            .update_flags(|flags| (flags & !(1 << 3)) | (dup as u8) << 3)
68    }
69
70    pub fn dup(&self) -> bool {
71        self.fixed_header.packet_type.flags() & 0x80 != 0
72    }
73
74    pub fn set_qos(&mut self, qos: QoSWithPacketIdentifier) {
75        let (qos, pkid) = qos.split();
76        self.fixed_header
77            .packet_type
78            .update_flags(|flags| (flags & !0b0110) | (qos as u8) << 1);
79        self.packet_identifier = pkid.map(PacketIdentifier);
80        self.fix_header_remaining_len();
81    }
82
83    pub fn qos(&self) -> QoSWithPacketIdentifier {
84        match self.packet_identifier {
85            None => QoSWithPacketIdentifier::Level0,
86            Some(pkid) => {
87                let qos_val = (self.fixed_header.packet_type.flags() & 0b0110) >> 1;
88                match qos_val {
89                    1 => QoSWithPacketIdentifier::Level1(pkid.0),
90                    2 => QoSWithPacketIdentifier::Level2(pkid.0),
91                    _ => unreachable!(),
92                }
93            }
94        }
95    }
96
97    pub fn set_retain(&mut self, ret: bool) {
98        self.fixed_header
99            .packet_type
100            .update_flags(|flags| (flags & !0b0001) | (ret as u8))
101    }
102
103    pub fn retain(&self) -> bool {
104        self.fixed_header.packet_type.flags() & 0b0001 != 0
105    }
106
107    pub fn set_topic_name(&mut self, topic_name: TopicName) {
108        self.topic_name = topic_name;
109        self.fix_header_remaining_len();
110    }
111
112    pub fn topic_name(&self) -> &str {
113        &self.topic_name[..]
114    }
115
116    pub fn payload(&self) -> &[u8] {
117        &self.payload
118    }
119
120    pub fn set_payload<P: Into<Vec<u8>>>(&mut self, payload: P) {
121        self.payload = payload.into();
122        self.fix_header_remaining_len();
123    }
124}
125
126impl DecodablePacket for PublishPacket {
127    type DecodePacketError = std::convert::Infallible;
128
129    fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>> {
130        let topic_name = TopicName::decode(reader)?;
131
132        let qos = (fixed_header.packet_type.flags() & 0b0110) >> 1;
133        let packet_identifier = if qos > 0 {
134            Some(PacketIdentifier::decode(reader)?)
135        } else {
136            None
137        };
138
139        let vhead_len =
140            topic_name.encoded_length() + packet_identifier.as_ref().map(|x| x.encoded_length()).unwrap_or(0);
141        let payload_len = fixed_header.remaining_length - vhead_len;
142
143        let payload = Vec::<u8>::decode_with(reader, Some(payload_len))?;
144
145        Ok(PublishPacket {
146            fixed_header,
147            topic_name,
148            packet_identifier,
149            payload,
150        })
151    }
152}
153
154/// `PUBLISH` packet by reference, for encoding only
155pub struct PublishPacketRef<'a> {
156    fixed_header: FixedHeader,
157    topic_name: &'a TopicNameRef,
158    packet_identifier: Option<PacketIdentifier>,
159    payload: &'a [u8],
160}
161
162impl<'a> PublishPacketRef<'a> {
163    pub fn new(topic_name: &'a TopicNameRef, qos: QoSWithPacketIdentifier, payload: &'a [u8]) -> PublishPacketRef<'a> {
164        let (qos, pkid) = qos.split();
165
166        let mut pk = PublishPacketRef {
167            fixed_header: FixedHeader::new(PacketType::publish(qos), 0),
168            topic_name,
169            packet_identifier: pkid.map(PacketIdentifier),
170            payload,
171        };
172        pk.fix_header_remaining_len();
173        pk
174    }
175
176    fn fix_header_remaining_len(&mut self) {
177        self.fixed_header.remaining_length =
178            self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length();
179    }
180}
181
182impl EncodablePacket for PublishPacketRef<'_> {
183    fn fixed_header(&self) -> &FixedHeader {
184        &self.fixed_header
185    }
186
187    fn encode_packet<W: Write>(&self, writer: &mut W) -> io::Result<()> {
188        self.topic_name.encode(writer)?;
189        self.packet_identifier.encode(writer)?;
190        self.payload.encode(writer)
191    }
192
193    fn encoded_packet_length(&self) -> u32 {
194        self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length()
195    }
196}
197
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::topic_name::TopicName;
    use crate::{Decodable, Encodable};

    /// Encoding a QoS 2 packet and decoding the bytes back yields an equal packet.
    #[test]
    fn test_publish_packet_basic() {
        let original = PublishPacket::new(
            TopicName::new("a/b".to_owned()).unwrap(),
            QoSWithPacketIdentifier::Level2(10),
            b"Hello world!".to_vec(),
        );

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();

        let round_tripped = PublishPacket::decode(&mut Cursor::new(encoded)).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// Regression test: changing the QoS must recompute the remaining length, because QoS 1
    /// adds a packet identifier that QoS 0 lacks.
    #[test]
    fn issue56() {
        let mut packet = PublishPacket::new(
            TopicName::new("topic").unwrap(),
            QoSWithPacketIdentifier::Level0,
            Vec::new(),
        );
        // Topic name only: no packet identifier and an empty payload.
        assert_eq!(packet.fixed_header().remaining_length, 7);

        packet.set_qos(QoSWithPacketIdentifier::Level1(1));
        // The packet identifier accounts for the extra two bytes.
        assert_eq!(packet.fixed_header().remaining_length, 9);
    }
}