1use crate::{
2 codec::{
3 mqtt_reader::{MqttBufReader, MqttReader},
4 mqtt_writer::{MqttBufWriter, MqttWriter},
5 write,
6 },
7 data::packet_type::PacketType,
8 error::{PacketReadError, PacketWriteError},
9 packets::{packet::Packet, packet_generic::PacketGeneric},
10};
11
12#[allow(async_fn_in_trait)]
13pub trait Connection {
14 async fn send(&mut self, buf: &[u8]) -> Result<(), PacketWriteError>;
16
17 async fn receive(&mut self, buf: &mut [u8]) -> Result<(), PacketReadError>;
19
20 async fn receive_if_ready(&mut self, buf: &mut [u8]) -> Result<bool, PacketReadError>;
42}
43
/// Sends and receives MQTT packets over a [`Connection`], encoding and
/// decoding them in a caller-provided byte buffer.
pub struct PacketClient<'a, C> {
    /// Working storage for outgoing encodes and incoming packets; packets
    /// returned by the receive methods borrow from it.
    buf: &'a mut [u8],
    /// The underlying byte transport.
    connection: C,
}
48
49impl<'a, C> PacketClient<'a, C>
52where
53 C: Connection,
54{
55 pub fn new(connection: C, buf: &'a mut [u8]) -> Self {
56 Self { connection, buf }
57 }
58
59 pub async fn send<P>(&mut self, packet: P) -> Result<(), PacketWriteError>
60 where
61 P: Packet + write::Write,
62 {
63 let len = {
64 let mut r = MqttBufWriter::new(self.buf);
65 r.put(&packet)?;
66 r.position()
67 };
68 self.connection.send(&self.buf[0..len]).await
69 }
70
71 pub async fn receive<const P: usize, const W: usize, const S: usize>(
72 &mut self,
73 ) -> Result<PacketGeneric<'_, P, W, S>, PacketReadError> {
74 self.connection.receive(&mut self.buf[0..1]).await?;
76
77 self.receive_rest_of_packet().await
79 }
80
81 pub async fn receive_if_ready<const P: usize, const W: usize, const S: usize>(
82 &mut self,
83 ) -> Result<Option<PacketGeneric<'_, P, W, S>>, PacketReadError> {
84 let packet_started = self
87 .connection
88 .receive_if_ready(&mut self.buf[0..1])
89 .await?;
90
91 if !packet_started {
92 return Ok(None);
93 }
94
95 let packet = self.receive_rest_of_packet().await?;
97 Ok(Some(packet))
98 }
99
100 async fn receive_rest_of_packet<const P: usize, const W: usize, const S: usize>(
101 &mut self,
102 ) -> Result<PacketGeneric<'_, P, W, S>, PacketReadError> {
103 let mut position: usize = 1;
104
105 if !PacketType::is_valid_first_header_byte(self.buf[0]) {
108 return Err(PacketReadError::InvalidPacketType);
109 }
110
111 self.connection
114 .receive(&mut self.buf[position..position + 1])
115 .await?;
116 position += 1;
117
118 for _extra in 0..3 {
120 if self.buf[position - 1] & 128 == 0 {
121 break;
122 } else {
123 self.connection
124 .receive(&mut self.buf[position..position + 1])
125 .await?;
126 position += 1;
127 }
128 }
129
130 if self.buf[position - 1] & 128 != 0 {
132 return Err(PacketReadError::InvalidVariableByteIntegerEncoding);
133 }
134
135 let remaining_length = {
137 let mut r = MqttBufReader::new(&self.buf[1..position]);
138 r.get_variable_u32()?
139 } as usize;
140
141 if position + remaining_length > self.buf.len() {
143 return Err(PacketReadError::PacketTooLargeForBuffer);
144 }
145
146 self.connection
148 .receive(&mut self.buf[position..position + remaining_length])
149 .await?;
150 position += remaining_length;
151
152 let packet_buf = &mut self.buf[0..position];
154 let mut packet_reader = MqttBufReader::new(packet_buf);
155 let packet_generic = packet_reader.get()?;
156
157 Ok(packet_generic)
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::mqtt_reader::MqttBufReader,
        data::{
            packet_identifier::PacketIdentifier,
            property::{ConnectProperty, SubscribeProperty},
            quality_of_service::QualityOfService,
        },
        packets::{
            connect::Connect,
            pingreq::Pingreq,
            subscribe::{Subscribe, SubscriptionRequest},
        },
    };
    use heapless::Vec;

    /// PINGREQ (packet type 12, flags 0) with an empty body.
    const ENCODED_PINGREQ: [u8; 2] = [0xC0, 0x00];
    /// PINGREQ type with a non-zero flags nibble; expected to be rejected.
    const INVALID_PACKET_TYPE: [u8; 2] = [0xC1, 0x00];
    /// Remaining length whose fourth byte still has the continuation bit set.
    const INVALID_LENGTH: [u8; 5] = [0xC0, 0x80, 0x80, 0x80, 0x80];
    /// Header announcing a 16-byte body, i.e. an 18-byte packet in total.
    const ENCODED_IMPLIES_PACKET_LENGTH_18: [u8; 2] = [0xC0, 0x10];

    /// Encoding of [`example_subscribe_packet`].
    const ENCODED_SUBSCRIBE: [u8; 30] = [
        0x82, 0x1C, 0x15, 0x38, 0x03, 0x0B, 0x80, 0x13, 0x00, 0x0A, 0x74, 0x65, 0x73, 0x74, 0x2f,
        0x74, 0x6f, 0x70, 0x69, 0x63, 0x00, 0x00, 0x06, 0x68, 0x65, 0x68, 0x65, 0x2F, 0x23, 0x01,
    ];

    /// Encoding of [`example_connect_packet`].
    const ENCODED_CONNECT: [u8; 18] = [
        0x10, 0x10, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x02, 0x00, 0x3c, 0x03, 0x21, 0x00,
        0x14, 0x00, 0x00,
    ];

    /// [`ENCODED_CONNECT`] with one extra trailing byte counted in the
    /// remaining length (0x11), which the CONNECT fields do not account for.
    const ENCODED_CONNECT_INCORRECT_PACKET_LENGTH: [u8; 19] = [
        0x10, 0x11, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x02, 0x00, 0x3c, 0x03, 0x21, 0x00,
        0x14, 0x00, 0x00, 0x00,
    ];

    fn example_subscribe_packet<'a>() -> Subscribe<'a, 16, 16> {
        let primary_request = SubscriptionRequest::new("test/topic", QualityOfService::Qos0);
        let mut additional_requests = Vec::new();
        additional_requests
            .push(SubscriptionRequest::new("hehe/#", QualityOfService::Qos1))
            .unwrap();
        let mut properties = Vec::new();
        properties
            .push(SubscribeProperty::SubscriptionIdentifier(2432.into()))
            .unwrap();
        Subscribe::new(
            PacketIdentifier(5432),
            primary_request,
            additional_requests,
            properties,
        )
    }

    fn example_connect_packet<'a>() -> Connect<'a, 16, 16> {
        let mut packet = Connect::new(60, None, None, "", true, None, Vec::new());
        packet
            .properties
            .push(ConnectProperty::ReceiveMaximum(20.into()))
            .unwrap();
        packet
    }

    /// In-memory [`Connection`]: `receive` consumes bytes from `read_buf` in
    /// order, `send` appends to `write_buf`, and `receive_if_ready` always
    /// reports data as ready.
    struct BufferConnection<'a> {
        reader: MqttBufReader<'a>,
        writer: MqttBufWriter<'a>,
    }

    impl<'a> BufferConnection<'a> {
        pub fn new(read_buf: &'a [u8], write_buf: &'a mut [u8]) -> Self {
            let reader = MqttBufReader::new(read_buf);
            let writer = MqttBufWriter::new(write_buf);
            BufferConnection { reader, writer }
        }
    }

    impl Connection for BufferConnection<'_> {
        async fn send(&mut self, buf: &[u8]) -> Result<(), PacketWriteError> {
            self.writer.put_slice(buf)
        }

        async fn receive(&mut self, buf: &mut [u8]) -> Result<(), PacketReadError> {
            let slice = self.reader.get_slice(buf.len())?;
            buf.copy_from_slice(slice);
            Ok(())
        }

        async fn receive_if_ready(&mut self, buf: &mut [u8]) -> Result<bool, PacketReadError> {
            self.receive(buf).await?;
            Ok(true)
        }
    }

    /// Asserts that `receive_if_ready` decodes `data` into `packet_generic`.
    async fn decode(data: &[u8], packet_generic: PacketGeneric<'_, 16, 16, 16>) {
        let mut write_buf = [];
        let connection = BufferConnection::new(data, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        let packet: Option<PacketGeneric<'_, 16, 16, 16>> =
            client.receive_if_ready().await.unwrap();

        assert_eq!(packet, Some(packet_generic));
    }

    /// Asserts that sending `packet` writes exactly the bytes in `encoded`.
    async fn encode<P: Packet + write::Write>(packet: P, encoded: &[u8]) {
        let read_buf = [];
        let mut write_buf = [0; 1024];
        let connection = BufferConnection::new(&read_buf, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        client.send(packet).await.unwrap();
        // Compare exactly what was written so extra trailing output is caught,
        // not just a correct prefix.
        let written = client.connection.writer.position();

        assert_eq!(&write_buf[..written], encoded);
    }

    /// Feeds `data` to a fresh client that uses `buf` as its packet buffer and
    /// asserts that `receive_if_ready` fails with `expected`.
    async fn assert_receive_error(data: &[u8], buf: &mut [u8], expected: PacketReadError) {
        let mut write_buf = [];
        let connection = BufferConnection::new(data, &mut write_buf);
        let mut client = PacketClient::new(connection, buf);

        let result: Result<Option<PacketGeneric<'_, 16, 16, 16>>, PacketReadError> =
            client.receive_if_ready().await;
        assert_eq!(result, Err(expected));
    }

    #[tokio::test]
    async fn error_on_decode_connect_with_incorrect_length() {
        assert_receive_error(
            &ENCODED_CONNECT_INCORRECT_PACKET_LENGTH,
            &mut [0; 1024],
            PacketReadError::IncorrectPacketLength,
        )
        .await;
    }

    #[tokio::test]
    async fn decode_pingreq() {
        decode(&ENCODED_PINGREQ, PacketGeneric::Pingreq(Pingreq::default())).await;
    }

    #[tokio::test]
    async fn encode_pingreq() {
        encode(Pingreq::default(), &ENCODED_PINGREQ).await;
    }

    #[tokio::test]
    async fn decode_subscribe() {
        let packet_generic = PacketGeneric::Subscribe(example_subscribe_packet());
        decode(&ENCODED_SUBSCRIBE, packet_generic).await;
    }

    #[tokio::test]
    async fn encode_subscribe() {
        encode(example_subscribe_packet(), &ENCODED_SUBSCRIBE).await;
    }

    #[tokio::test]
    async fn decode_connect() {
        decode(
            &ENCODED_CONNECT,
            PacketGeneric::Connect(example_connect_packet()),
        )
        .await;
    }

    #[tokio::test]
    async fn receive_decodes_connect() {
        let mut write_buf = [];
        let connection = BufferConnection::new(&ENCODED_CONNECT, &mut write_buf);

        let mut buf = [0; 1024];
        let mut client = PacketClient::new(connection, &mut buf);

        let packet: PacketGeneric<'_, 16, 16, 16> = client.receive().await.unwrap();
        assert_eq!(packet, PacketGeneric::Connect(example_connect_packet()));
    }

    #[tokio::test]
    async fn encode_connect() {
        encode(example_connect_packet(), &ENCODED_CONNECT).await;
    }

    #[tokio::test]
    async fn decode_fails_on_invalid_packet_type() {
        assert_receive_error(
            &INVALID_PACKET_TYPE,
            &mut [0; 1024],
            PacketReadError::InvalidPacketType,
        )
        .await;
    }

    #[tokio::test]
    async fn decode_fails_on_invalid_length_encoding() {
        assert_receive_error(
            &INVALID_LENGTH,
            &mut [0; 1024],
            PacketReadError::InvalidVariableByteIntegerEncoding,
        )
        .await;
    }

    #[tokio::test]
    async fn decode_fails_on_packet_length_bigger_than_buffer() {
        // 18-byte packet against a 17-byte buffer.
        assert_receive_error(
            &ENCODED_IMPLIES_PACKET_LENGTH_18,
            &mut [0; 17],
            PacketReadError::PacketTooLargeForBuffer,
        )
        .await;
    }
}