// embedded_mqtt/packet.rs

1use core::{
2    default::Default,
3    convert::TryFrom,
4    cmp::min,
5    result::Result,
6};
7
8use crate::{
9    fixed_header::{self, FixedHeader},
10    variable_header::{self, VariableHeader},
11    payload::{self, Payload},
12    status::Status,
13    error::{DecodeError, EncodeError},
14    codec::{Decodable, Encodable},
15    qos,
16};
17
18#[derive(Debug)]
19#[allow(dead_code)]
20pub struct Packet<'a> {
21    fixed_header: FixedHeader,
22    variable_header: Option<VariableHeader<'a>>,
23    payload: Payload<'a>,
24}
25
26/// A full MQTT packet with fixed header, variable header and payload.
27///
28/// Variable header and payload are optional for some packet types.
29impl<'a> Packet<'a> {
30    /// Create a CONNECT packet.
31    pub fn connect(variable_header: variable_header::connect::Connect<'a>, payload: payload::connect::Connect<'a>) -> Result<Self, EncodeError> {
32        Self::packet(
33            fixed_header::PacketType::Connect,
34            fixed_header::PacketFlags::CONNECT,
35            Some(variable_header::VariableHeader::Connect(variable_header)),
36            payload::Payload::Connect(payload)
37        )
38    }
39
40    /// Create a SUBSCRIBE packet.
41    pub fn subscribe(variable_header: variable_header::packet_identifier::PacketIdentifier, payload: payload::subscribe::Subscribe<'a>) -> Result<Self, EncodeError> {
42        Self::packet(
43            fixed_header::PacketType::Subscribe,
44            fixed_header::PacketFlags::SUBSCRIBE,
45            Some(variable_header::VariableHeader::Subscribe(variable_header)),
46            payload::Payload::Subscribe(payload)
47        )
48    }
49
50    /// Create a PUBLISH packet.
51    pub fn publish(flags: fixed_header::PublishFlags, variable_header: variable_header::publish::Publish<'a>, payload: &'a [u8]) -> Result<Self, EncodeError> {
52        // TODO encode this using type states
53        assert!(flags.qos().expect("valid qos") == qos::QoS::AtMostOnce || variable_header.packet_identifier().is_some());
54
55        Self::packet(
56            fixed_header::PacketType::Publish,
57            flags.into(),
58            Some(variable_header::VariableHeader::Publish(variable_header)),
59            payload::Payload::Bytes(payload)
60        )
61    }
62
63    pub fn puback(variable_header: variable_header::packet_identifier::PacketIdentifier) -> Result<Self, EncodeError> {
64        Self::packet(
65            fixed_header::PacketType::Puback,
66            fixed_header::PacketFlags::PUBACK,
67            Some(variable_header::VariableHeader::Puback(variable_header)),
68            Default::default(),
69        )
70    }
71
72    /// Create a PINGREQ packet.
73    pub fn pingreq() -> Self {
74        Self {
75            fixed_header: FixedHeader::new(
76                fixed_header::PacketType::Pingreq,
77                fixed_header::PacketFlags::PINGREQ,
78                0,
79            ),
80            variable_header: None,
81            payload: Default::default(),
82        }
83    }
84
85    /// Create a PINGRESP packet.
86    pub fn pingresp() -> Self {
87        Self {
88            fixed_header: FixedHeader::new(
89                fixed_header::PacketType::Pingresp,
90                fixed_header::PacketFlags::PINGRESP,
91                0,
92            ),
93            variable_header: None,
94            payload: Default::default(),
95        }
96    }
97
98    /// Create a packet with the given type, flags, variable header and payload.
99    ///
100    /// Constructs a fixed header with the appropriate `len` field for the given
101    /// variable header and payload.
102    fn packet(r#type: fixed_header::PacketType, flags: fixed_header::PacketFlags, variable_header: Option<VariableHeader<'a>>, payload: Payload<'a>) -> Result<Self, EncodeError> {
103        let len = u32::try_from(
104            variable_header.as_ref().map(VariableHeader::encoded_len).unwrap_or(0) +
105            payload.encoded_len()
106        )?;
107
108        Ok(Self {
109            fixed_header: FixedHeader::new(
110                r#type,
111                flags,
112                len,
113            ),
114            variable_header: variable_header,
115            payload: payload,
116        })
117    }
118
119    /// Return a reference to the fixed header of the packet.
120    ///
121    /// The len field of the returned header will be valid.
122    pub fn fixed_header(&self) -> &FixedHeader {
123        &self.fixed_header
124    }
125
126    /// Return a reference to the variable header of the packet.
127    pub fn variable_header(&self) -> &Option<VariableHeader> {
128        &self.variable_header
129    }
130
131    /// Return a reference to the payload of the packet.
132    pub fn payload(&self) -> &Payload {
133        &self.payload
134    }
135}
136
137impl<'a> Decodable<'a> for Packet<'a> {
138    /// Decode any MQTT packet from a pre-allocated buffer.
139    ///
140    /// If an unrecoverable error occurs an `Err(x)` is returned, the caller should
141    /// disconnect and network connection and discard the contents of the connection
142    /// receive buffer.
143    /// 
144    /// Decoding may return an `Ok(Status::Partial(x))` in which case the caller
145    /// should buffer at most `x` more bytes and then attempt decoding again.
146    ///
147    /// If decoding succeeds an `Ok(Status::Complete(x))` will be returned
148    /// containing the number of bytes read from the buffer and the decoded packet.
149    /// The lifetime of the decoded packet is tied to the input buffer.
150    fn decode(bytes: &'a [u8]) -> Result<Status<(usize, Self)>, DecodeError> {
151        let (fixed_header_offset, fixed_header) = read!(FixedHeader::decode, bytes, 0);
152
153        let (variable_header_consumed, variable_header) = if let Some(result) = VariableHeader::decode(fixed_header.r#type(), fixed_header.flags(), &bytes[fixed_header_offset..]) {
154            let (variable_header_offset, variable_header) = complete!(result);
155            (variable_header_offset, Some(variable_header))
156        } else {
157            (0, None)
158        };
159
160        let payload_len = fixed_header.len() as usize - variable_header_consumed;
161
162        let available = bytes.len() - (fixed_header_offset + variable_header_consumed);
163        let needed = payload_len - min(available, payload_len);
164        if needed > 0 {
165            return Ok(Status::Partial(needed));
166        }
167
168        let payload_bytes = &bytes[fixed_header_offset+variable_header_consumed..fixed_header_offset+variable_header_consumed+payload_len];
169
170        let payload = if let Some(result) = Payload::decode(fixed_header.r#type(), payload_bytes) {
171            match result {
172                Err(e) => return Err(e),
173                Ok(Status::Partial(n)) => return Ok(Status::Partial(n)),
174                Ok(Status::Complete((_, payload))) => payload,
175            }
176        } else {
177            payload::Payload::Bytes(payload_bytes)
178        };
179
180        Ok(Status::Complete((fixed_header_offset + fixed_header.len() as usize, Self {
181            fixed_header,
182            variable_header,
183            payload,
184        })))
185    }
186}
187
188impl<'a> Encodable for Packet<'a> {
189    /// Calculate the exact length of the fully encoded packet.
190    ///
191    /// The encode buffer will need to hold at least this number of bytes.
192    fn encoded_len(&self) -> usize {
193        self.fixed_header.encoded_len() + self.fixed_header.len() as usize
194    }
195
196    /// Encode a packet for sending over a network connection.
197    ///
198    /// If encoding fails an `Err(x)` is returned.
199    ///
200    /// If encoding succeeds an `Ok(written)` is returned with the number of
201    /// bytes written to the buffer.
202    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
203        let mut offset = 0;
204
205        offset = {
206            let o = self.fixed_header.encode(&mut bytes[offset..])?;
207            offset + o
208        };
209
210        if let Some(ref variable_header) = self.variable_header {
211            offset = {
212                let o = variable_header.encode(&mut bytes[offset..])?;
213                offset + o
214            };
215        }
216
217        let offset = {
218            let o = self.payload.encode(&mut bytes[offset..])?;
219            offset + o
220        };
221
222        Ok(offset)
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Size bookkeeping for a QoS 1 PUBLISH with a two-byte body.
    #[test]
    fn encode_publish() {
        let body = b"{}";
        assert_eq!(2, body.len());

        // QoS 1, so the variable header carries a packet identifier.
        let mut flags = fixed_header::PublishFlags::default();
        flags.set_qos(qos::QoS::AtLeastOnce);
        let header = variable_header::publish::Publish::new("a/b", Some(2));

        let packet = Packet::publish(flags, header, body).expect("valid packet");

        // Presumably: topic "a/b" with a 2-byte length prefix (5) plus a 2-byte
        // packet identifier gives 7; adding the 2-byte body gives a remaining
        // length of 9; a 2-byte fixed header brings the packet to 11.
        let variable_header_len = packet
            .variable_header()
            .as_ref()
            .expect("variable header")
            .encoded_len();
        assert_eq!(7, variable_header_len);
        assert_eq!(2, packet.payload().encoded_len());
        assert_eq!(9, packet.fixed_header().len());
        assert_eq!(2, packet.fixed_header().encoded_len());
        assert_eq!(11, packet.encoded_len());
    }

    /// Size bookkeeping for a SUBSCRIBE with three topic filters.
    #[test]
    fn encode_subscribe() {
        let packet = Packet::subscribe(
            variable_header::packet_identifier::PacketIdentifier::new(1),
            payload::subscribe::Subscribe::new(&[
                ("c/a", qos::QoS::AtMostOnce),
                ("c/b", qos::QoS::AtLeastOnce),
                ("c/c", qos::QoS::ExactlyOnce),
            ]),
        )
        .expect("valid packet");

        // Presumably: a 2-byte packet identifier, and per filter a
        // length-prefixed 3-byte topic plus one QoS byte (6), so 18 for three;
        // remaining length 20, plus a 2-byte fixed header gives 22.
        let variable_header_len = packet
            .variable_header()
            .as_ref()
            .expect("variable header")
            .encoded_len();
        assert_eq!(2, variable_header_len);
        assert_eq!(18, packet.payload().encoded_len());
        assert_eq!(20, packet.fixed_header().len());
        assert_eq!(2, packet.fixed_header().encoded_len());
        assert_eq!(22, packet.encoded_len());
    }
}