Skip to main content

mountain_mqtt/
packet_client.rs

1use crate::{
2    codec::{
3        mqtt_reader::{MqttBufReader, MqttReader},
4        mqtt_writer::{MqttBufWriter, MqttWriter},
5        write,
6    },
7    data::packet_type::PacketType,
8    error::{PacketReadError, PacketWriteError},
9    packets::{packet::Packet, packet_generic::PacketGeneric},
10};
11
12#[allow(async_fn_in_trait)]
13pub trait Connection {
14    // Send and flush all data in `buf`
15    async fn send(&mut self, buf: &[u8]) -> Result<(), PacketWriteError>;
16
17    // Receive into buffer, waiting to fill it. This may need to await more data.
18    async fn receive(&mut self, buf: &mut [u8]) -> Result<(), PacketReadError>;
19
20    /// If no data at all is ready, then return immediately with `Ok(false)`,
21    /// leaving the underlying stream and `buf` unaltered.
22    /// If any data is ready  then receive into buffer, waiting to fill it,
23    /// then return `Ok(true)`.
24    /// Note that this method can still await data, because it is only required
25    /// to return `Ok(false)` in the case where no data at all is ready. If less
26    /// data is available than `buf.len()`, the method may still proceed, reading
27    /// the available data and then waiting for more to fill `buf`.
28    /// This is fairly well suited to MQTT packets - first call `receive_if_ready`
29    /// with a `buf` of length 1 - if this returns `Ok(false)` then sleep for
30    /// a reasonable interval and try again. If it returns `Ok(true)`, continue
31    /// reading a whole packet using `receive`, using larger buffers if possible.
32    /// While this may lead to an indefinite await for the rest of the packet, this
33    /// should only occur in the case of a malicious server or poor connection, not
34    /// in the most common case where there is a long gap between incoming packets,
35    /// but once the first byte of a packet is received the rest is received quickly
36    /// afterwards.
37    /// This approach is used to make it easier to use a variety of underlying streams,
38    /// since we only need something like embedded-async's `ReadReady` trait, or
39    /// tokio's `TCPStream.try_read`
40    /// More sophisticated approaches are definitely possible.
41    async fn receive_if_ready(&mut self, buf: &mut [u8]) -> Result<bool, PacketReadError>;
42}
43
/// Sends and receives MQTT packets over a [`Connection`], encoding and
/// decoding them in a caller-supplied buffer.
///
/// The buffer bounds the size of any single packet that can be sent or
/// received, and received packets borrow from it.
pub struct PacketClient<'a, C> {
    /// Transport used to exchange encoded packets.
    connection: C,
    /// Scratch space for encoding outgoing and decoding incoming packets.
    buf: &'a mut [u8],
}
50
51impl<'a, C> PacketClient<'a, C>
52where
53    C: Connection,
54{
55    pub fn new(connection: C, buf: &'a mut [u8]) -> Self {
56        Self { connection, buf }
57    }
58
59    pub async fn send<P>(&mut self, packet: P) -> Result<(), PacketWriteError>
60    where
61        P: Packet + write::Write,
62    {
63        let len = {
64            let mut r = MqttBufWriter::new(self.buf);
65            r.put(&packet)?;
66            r.position()
67        };
68        self.connection.send(&self.buf[0..len]).await
69    }
70
71    pub async fn receive<const P: usize, const W: usize, const S: usize>(
72        &mut self,
73    ) -> Result<PacketGeneric<'_, P, W, S>, PacketReadError> {
74        // First, try to read one byte with blocking
75        self.connection.receive(&mut self.buf[0..1]).await?;
76
77        // Then the rest of the packet
78        self.receive_rest_of_packet().await
79    }
80
81    pub async fn receive_if_ready<const P: usize, const W: usize, const S: usize>(
82        &mut self,
83    ) -> Result<Option<PacketGeneric<'_, P, W, S>>, PacketReadError> {
84        // First, try to read one byte without blocking - if this returns false, no packet is ready
85        // and we can return immediately to avoid blocking
86        let packet_started = self
87            .connection
88            .receive_if_ready(&mut self.buf[0..1])
89            .await?;
90
91        if !packet_started {
92            return Ok(None);
93        }
94
95        // We have packet, and its first byte is in our buffer, so receive the rest of the packet
96        let packet = self.receive_rest_of_packet().await?;
97        Ok(Some(packet))
98    }
99
100    async fn receive_rest_of_packet<const P: usize, const W: usize, const S: usize>(
101        &mut self,
102    ) -> Result<PacketGeneric<'_, P, W, S>, PacketReadError> {
103        let mut position: usize = 1;
104
105        // Check first header byte is valid, if not we can error early without
106        // trying to read the rest of a packet
107        if !PacketType::is_valid_first_header_byte(self.buf[0]) {
108            return Err(PacketReadError::InvalidPacketType);
109        }
110
111        // Read up to 4 bytes into buffer as variable u32
112        // First byte always exists
113        self.connection
114            .receive(&mut self.buf[position..position + 1])
115            .await?;
116        position += 1;
117
118        // Read up to 3 more bytes looking for the end of the encoded length
119        for _extra in 0..3 {
120            if self.buf[position - 1] & 128 == 0 {
121                break;
122            } else {
123                self.connection
124                    .receive(&mut self.buf[position..position + 1])
125                    .await?;
126                position += 1;
127            }
128        }
129
130        // Error if we didn't see the end of the length
131        if self.buf[position - 1] & 128 != 0 {
132            return Err(PacketReadError::InvalidVariableByteIntegerEncoding);
133        }
134
135        // We have a valid length, decode it
136        let remaining_length = {
137            let mut r = MqttBufReader::new(&self.buf[1..position]);
138            r.get_variable_u32()?
139        } as usize;
140
141        // If packet will not fit in buffer, error
142        if position + remaining_length > self.buf.len() {
143            return Err(PacketReadError::PacketTooLargeForBuffer);
144        }
145
146        // Read the rest of the packet
147        self.connection
148            .receive(&mut self.buf[position..position + remaining_length])
149            .await?;
150        position += remaining_length;
151
152        // We can now decode the packet from the buffer
153        let packet_buf = &mut self.buf[0..position];
154        let mut packet_reader = MqttBufReader::new(packet_buf);
155        let packet_generic = packet_reader.get()?;
156
157        Ok(packet_generic)
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::mqtt_reader::MqttBufReader,
        data::{
            packet_identifier::PacketIdentifier,
            property::{ConnectProperty, SubscribeProperty},
            quality_of_service::QualityOfService,
        },
        packets::{
            connect::Connect,
            pingreq::Pingreq,
            subscribe::{Subscribe, SubscriptionRequest},
        },
    };
    use heapless::Vec;

    // 0xC0 is packet type 12 (PINGREQ) with zero flags, then a remaining length of 0.
    const ENCODED_PINGREQ: [u8; 2] = [0xC0, 0x00];
    // PINGREQ type with non-zero flags, so the first header byte is invalid.
    const INVALID_PACKET_TYPE: [u8; 2] = [0xC1, 0x00];
    // Every remaining-length byte has its continuation bit set, so the encoding never terminates.
    const INVALID_LENGTH: [u8; 5] = [0xC0, 0x80, 0x80, 0x80, 0x80];
    // A first header byte, then a single-byte remaining length of 16. This implies a total
    // packet length of 18: the 16 remaining bytes plus the header byte and the length byte.
    const ENCODED_IMPLIES_PACKET_LENGTH_18: [u8; 2] = [0xC0, 0x10];

    const ENCODED_SUBSCRIBE: [u8; 30] = [
        0x82, 0x1C, 0x15, 0x38, 0x03, 0x0B, 0x80, 0x13, 0x00, 0x0A, 0x74, 0x65, 0x73, 0x74, 0x2f,
        0x74, 0x6f, 0x70, 0x69, 0x63, 0x00, 0x00, 0x06, 0x68, 0x65, 0x68, 0x65, 0x2F, 0x23, 0x01,
    ];

    const ENCODED_CONNECT: [u8; 18] = [
        0x10, 0x10, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x02, 0x00, 0x3c, 0x03, 0x21, 0x00,
        0x14, 0x00, 0x00,
    ];

    // Copy of the valid ENCODED_CONNECT above, except that the "remaining length" in the fixed
    // header is 1 too large, so decoding should produce an incorrect packet length error. A
    // padding byte is appended so the client can read the whole claimed packet into its buffer.
    // It then gets as far as decoding it and hits the length mismatch in the PacketGeneric
    // Read implementation.
    const ENCODED_CONNECT_INCORRECT_PACKET_LENGTH: [u8; 19] = [
        0x10, 0x11, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x02, 0x00, 0x3c, 0x03, 0x21, 0x00,
        0x14, 0x00, 0x00, 0x00,
    ];

    fn example_subscribe_packet<'a>() -> Subscribe<'a, 16, 16> {
        let primary_request = SubscriptionRequest::new("test/topic", QualityOfService::Qos0);
        let mut additional_requests = Vec::new();
        additional_requests
            .push(SubscriptionRequest::new("hehe/#", QualityOfService::Qos1))
            .unwrap();
        let mut properties = Vec::new();
        properties
            .push(SubscribeProperty::SubscriptionIdentifier(2432.into()))
            .unwrap();
        Subscribe::new(
            PacketIdentifier(5432),
            primary_request,
            additional_requests,
            properties,
        )
    }

    fn example_connect_packet<'a>() -> Connect<'a, 16, 16> {
        let mut packet = Connect::new(60, None, None, "", true, None, Vec::new());
        packet
            .properties
            .push(ConnectProperty::ReceiveMaximum(20.into()))
            .unwrap();
        packet
    }

    /// In-memory `Connection`: reads come from a fixed slice and writes go to another.
    struct BufferConnection<'a> {
        reader: MqttBufReader<'a>,
        writer: MqttBufWriter<'a>,
    }

    impl<'a> BufferConnection<'a> {
        pub fn new(read_buf: &'a [u8], write_buf: &'a mut [u8]) -> Self {
            let reader = MqttBufReader::new(read_buf);
            let writer = MqttBufWriter::new(write_buf);
            BufferConnection { reader, writer }
        }
    }

    impl Connection for BufferConnection<'_> {
        async fn send(&mut self, buf: &[u8]) -> Result<(), PacketWriteError> {
            self.writer.put_slice(buf)
        }

        async fn receive(&mut self, buf: &mut [u8]) -> Result<(), PacketReadError> {
            let slice = self.reader.get_slice(buf.len())?;
            buf.copy_from_slice(slice);
            Ok(())
        }

        async fn receive_if_ready(&mut self, buf: &mut [u8]) -> Result<bool, PacketReadError> {
            // Assume data is always ready; this errors on underflow, as the tests require.
            self.receive(buf).await?;
            Ok(true)
        }
    }

    /// Assert that `receive_if_ready` decodes `data` into `packet_generic`.
    async fn decode(data: &[u8], packet_generic: PacketGeneric<'_, 16, 16, 16>) {
        let mut write_buf = [];
        let connection = BufferConnection::new(data, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        let packet: Option<PacketGeneric<'_, 16, 16, 16>> =
            client.receive_if_ready().await.unwrap();

        assert_eq!(packet, Some(packet_generic));
    }

    /// Assert that sending `packet` writes exactly `encoded` to the connection.
    async fn encode<P: Packet + write::Write>(packet: P, encoded: &[u8]) {
        let read_buf = [];
        let mut write_buf = [0; 1024];
        let connection = BufferConnection::new(&read_buf, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        client.send(packet).await.unwrap();

        assert_eq!(&write_buf[0..encoded.len()], encoded);
    }

    #[tokio::test]
    async fn error_on_decode_connect_with_incorrect_length() {
        let mut write_buf = [];
        let connection =
            BufferConnection::new(&ENCODED_CONNECT_INCORRECT_PACKET_LENGTH, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        let packet: Result<Option<PacketGeneric<'_, 16, 16, 16>>, PacketReadError> =
            client.receive_if_ready().await;
        assert_eq!(packet, Err(PacketReadError::IncorrectPacketLength));
    }

    #[tokio::test]
    async fn decode_pingreq() {
        decode(
            &ENCODED_PINGREQ,
            PacketGeneric::Pingreq(Pingreq::default()),
        )
        .await;
    }

    #[tokio::test]
    async fn encode_pingreq() {
        encode(Pingreq::default(), &ENCODED_PINGREQ).await;
    }

    #[tokio::test]
    async fn decode_subscribe() {
        let packet = example_subscribe_packet();
        let packet_generic = PacketGeneric::Subscribe(packet);
        decode(&ENCODED_SUBSCRIBE, packet_generic).await;
    }

    #[tokio::test]
    async fn encode_subscribe() {
        encode(example_subscribe_packet(), &ENCODED_SUBSCRIBE).await;
    }

    #[tokio::test]
    async fn decode_connect() {
        decode(
            &ENCODED_CONNECT,
            PacketGeneric::Connect(example_connect_packet()),
        )
        .await;
    }

    #[tokio::test]
    async fn encode_connect() {
        encode(example_connect_packet(), &ENCODED_CONNECT).await;
    }

    #[tokio::test]
    async fn receive_awaits_and_decodes_connect() {
        let mut write_buf = [];
        let connection = BufferConnection::new(&ENCODED_CONNECT, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        let packet: PacketGeneric<'_, 16, 16, 16> = client.receive().await.unwrap();
        assert_eq!(packet, PacketGeneric::Connect(example_connect_packet()));
    }

    #[tokio::test]
    async fn decode_fails_on_invalid_packet_type() {
        let mut write_buf = [];
        let connection = BufferConnection::new(&INVALID_PACKET_TYPE, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        assert_eq!(
            client.receive_if_ready::<16, 16, 16>().await,
            Err(PacketReadError::InvalidPacketType)
        );
    }

    #[tokio::test]
    async fn decode_fails_on_invalid_length_encoding() {
        let mut write_buf = [];
        let connection = BufferConnection::new(&INVALID_LENGTH, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        assert_eq!(
            client.receive_if_ready::<16, 16, 16>().await,
            Err(PacketReadError::InvalidVariableByteIntegerEncoding)
        );
    }

    #[tokio::test]
    async fn decode_fails_on_packet_length_bigger_than_buffer() {
        let mut write_buf = [];
        let connection = BufferConnection::new(&ENCODED_IMPLIES_PACKET_LENGTH_18, &mut write_buf);

        let mut buf = [0; 17];
        let mut client = PacketClient::new(connection, &mut buf);

        assert_eq!(
            client.receive_if_ready::<16, 16, 16>().await,
            Err(PacketReadError::PacketTooLargeForBuffer)
        );
    }
}