rumq_core/mqtt4/asyncserialize.rs

1use crate::mqtt4::*;
2
3use async_trait::async_trait;
4use tokio::io::AsyncWriteExt;
5
6/// Mqtt awareness on top of tokio's `AsyncWrite`
7#[async_trait]
8pub trait AsyncMqttWrite: AsyncWriteExt + Unpin {
9    async fn async_mqtt_write(&mut self, packet: &Packet) -> Result<(), Error> {
10        match packet {
11            Packet::Connect(connect) => {
12                self.write_u8(0b00010000).await?;
13                self.write_remaining_length(connect.len()).await?;
14                self.write_mqtt_string("MQTT").await?;
15                self.write_u8(0x04).await?;
16
17                let mut connect_flags = 0;
18                if connect.clean_session {
19                    connect_flags |= 0x02;
20                }
21
22                match &connect.last_will {
23                    Some(w) if w.retain => connect_flags |= 0x04 | (w.qos as u8) << 3 | 0x20,
24                    Some(w) => connect_flags |= 0x04 | (w.qos as u8) << 3,
25                    None => ()
26                }
27
28                if let Some(_) = connect.password {
29                    connect_flags |= 0x40;
30                }
31                if let Some(_) = connect.username {
32                    connect_flags |= 0x80;
33                }
34
35                self.write_u8(connect_flags).await?;
36                self.write_u16(connect.keep_alive).await?;
37                self.write_mqtt_string(connect.client_id.as_ref()).await?;
38
39                if let Some(ref last_will) = connect.last_will {
40                    self.write_mqtt_string(last_will.topic.as_ref()).await?;
41                    self.write_mqtt_string(last_will.message.as_ref()).await?;
42                }
43                if let Some(ref username) = connect.username {
44                    self.write_mqtt_string(username).await?;
45                }
46                if let Some(ref password) = connect.password {
47                    self.write_mqtt_string(password).await?;
48                }
49                Ok(())
50            }
51            Packet::Connack(connack) => {
52                let session_present = connack.session_present as u8;
53                let code = connack.code as u8;
54                let data = [0x20, 0x02, session_present, code];
55                self.write_all(&data).await?;
56                Ok(())
57            }
58            Packet::Publish(publish) => {
59                self.write_u8(0b00110000 | publish.retain as u8 | ((publish.qos as u8) << 1) | ((publish.dup as u8) << 3)).await?;
60                let mut len = publish.topic_name.len() + 2 + publish.payload.len();
61
62                if publish.qos != QoS::AtMostOnce && None != publish.pkid {
63                    len += 2;
64                }
65
66                self.write_remaining_length(len).await?;
67                self.write_mqtt_string(publish.topic_name.as_str()).await?;
68                if publish.qos != QoS::AtMostOnce {
69                    if let Some(pkid) = publish.pkid {
70                        self.write_u16(pkid.0).await?;
71                    }
72                }
73                self.write_all(&publish.payload.as_ref()).await?;
74                Ok(())
75            }
76            Packet::Puback(pkid) => {
77                self.write_all(&[0x40, 0x02]).await?;
78                self.write_u16(pkid.0).await?;
79                Ok(())
80            }
81            Packet::Pubrec(pkid) => {
82                self.write_all(&[0x50, 0x02]).await?;
83                self.write_u16(pkid.0).await?;
84                Ok(())
85            }
86            Packet::Pubrel(pkid) => {
87                self.write_all(&[0x62, 0x02]).await?;
88                self.write_u16(pkid.0).await?;
89                Ok(())
90            }
91            Packet::Pubcomp(pkid) => {
92                self.write_all(&[0x70, 0x02]).await?;
93                self.write_u16(pkid.0).await?;
94                Ok(())
95            }
96            Packet::Subscribe(subscribe) => {
97                self.write_all(&[0x82]).await?;
98                let len = 2 + subscribe.topics.iter().fold(0, |s, ref t| s + t.topic_path.len() + 3);
99                
100                self.write_remaining_length(len).await?;
101                self.write_u16(subscribe.pkid.0).await?;
102                for topic in subscribe.topics.as_ref() as &Vec<SubscribeTopic> {
103                    self.write_mqtt_string(topic.topic_path.as_str()).await?;
104                    self.write_u8(topic.qos as u8).await?;
105                }
106                Ok(())
107            }
108            Packet::Suback(suback) => {
109                self.write_all(&[0x90]).await?;
110                self.write_remaining_length(suback.return_codes.len() + 2).await?;
111                self.write_u16(suback.pkid.0).await?;
112                
113                let payload: Vec<u8> = suback.return_codes.iter().map(|&code| match code {
114                    SubscribeReturnCodes::Success(qos) => qos as u8,
115                    SubscribeReturnCodes::Failure => 0x80,
116                }).collect();
117                
118                self.write_all(&payload).await?;
119                Ok(())
120            }
121            Packet::Unsubscribe(unsubscribe) => {
122                self.write_all(&[0xA2]).await?;
123                let len = 2 + unsubscribe.topics.iter().fold(0, |s, ref topic| s + topic.len() + 2);
124                self.write_remaining_length(len).await?;
125                self.write_u16(unsubscribe.pkid.0).await?;
126                
127                for topic in unsubscribe.topics.as_ref() as &Vec<String> {
128                    self.write_mqtt_string(topic.as_str()).await?;
129                }
130                Ok(())
131            }
132            Packet::Unsuback(pkid) => {
133                self.write_all(&[0xB0, 0x02]).await?;
134                self.write_u16(pkid.0).await?;
135                Ok(())
136            }
137            Packet::Pingreq => {
138                self.write_all(&[0xc0, 0]).await?;
139                Ok(())
140            }
141            Packet::Pingresp => {
142                self.write_all(&[0xd0, 0]).await?;
143                Ok(())
144            }
145            Packet::Disconnect => {
146                self.write_all(&[0xe0, 0]).await?;
147                Ok(())
148            }
149        }
150    }
151
152    async fn write_mqtt_string(&mut self, string: &str) -> Result<(), Error> {
153        self.write_u16(string.len() as u16).await?;
154        self.write_all(string.as_bytes()).await?;
155        Ok(())
156    }
157
158    async fn write_remaining_length(&mut self, len: usize) -> Result<(), Error> {
159        if len > 268_435_455 {
160            return Err(Error::PayloadTooLong);
161        }
162
163        let mut done = false;
164        let mut x = len;
165
166        while !done {
167            let mut byte = (x % 128) as u8;
168            x = x / 128;
169            if x > 0 {
170                byte = byte | 128;
171            }
172            self.write_u8(byte).await?;
173            done = x <= 0;
174        }
175        Ok(())
176    }
177}
178
179/// Implement MqttWrite for every AsyncWriteExt type (and hence AsyncWrite type)
180impl<W: AsyncWriteExt + ?Sized + Unpin> AsyncMqttWrite for W {}
181
#[cfg(test)]
mod test {
    use super::AsyncMqttWrite;
    use super::{Connack, Connect, Packet, Publish, Subscribe};
    use super::{ConnectReturnCode, LastWill, PacketIdentifier, Protocol, QoS, SubscribeTopic};

    /// Serializes `packet` into an in-memory buffer and returns the bytes.
    async fn serialize(packet: &Packet) -> Vec<u8> {
        let mut stream: Vec<u8> = Vec::new();
        stream.async_mqtt_write(packet).await.unwrap();
        stream
    }

    /// Encodes `len` as a remaining length into an in-memory buffer.
    async fn remaining_length(len: usize) -> Vec<u8> {
        let mut stream: Vec<u8> = Vec::new();
        stream.write_remaining_length(len).await.unwrap();
        stream
    }

    #[tokio::test]
    async fn write_packet_connect_mqtt_protocol_works() {
        let connect = Packet::Connect(Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill {
                topic: "/a".to_owned(),
                message: "offline".to_owned(),
                retain: false,
                qos: QoS::AtLeastOnce,
            }),
            username: Some("rust".to_owned()),
            password: Some("mq".to_owned()),
        });

        assert_eq!(
            serialize(&connect).await,
            vec![
                0x10, 39, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04,
                0b1100_1110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session
                0x00, 0x0a, // 10 sec
                0x00, 0x04, b't', b'e', b's', b't', // client_id
                0x00, 0x02, b'/', b'a', // will topic = '/a'
                0x00, 0x07, b'o', b'f', b'f', b'l', b'i', b'n', b'e', // will msg = 'offline'
                0x00, 0x04, b'r', b'u', b's', b't', // username = 'rust'
                0x00, 0x02, b'm', b'q', // password = 'mq'
            ]
        );
    }

    #[tokio::test]
    async fn write_packet_connect_will_retain_without_credentials_works() {
        let connect = Packet::Connect(Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 0,
            client_id: "c".to_owned(),
            clean_session: false,
            last_will: Some(LastWill {
                topic: "t".to_owned(),
                message: "m".to_owned(),
                retain: true,
                qos: QoS::ExactlyOnce,
            }),
            username: None,
            password: None,
        });

        assert_eq!(
            serialize(&connect).await,
            vec![
                0x10, 19, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04,
                0b0011_0100, // +will retain, will qos=2, +last_will, -clean_session
                0x00, 0x00, // keep alive disabled
                0x00, 0x01, b'c', // client_id
                0x00, 0x01, b't', // will topic
                0x00, 0x01, b'm', // will msg
            ]
        );
    }

    #[tokio::test]
    async fn write_packet_connack_works() {
        let connack = Packet::Connack(Connack {
            session_present: true,
            code: ConnectReturnCode::Accepted,
        });

        assert_eq!(serialize(&connack).await, vec![0b0010_0000, 0x02, 0x01, 0x00]);
    }

    #[tokio::test]
    async fn write_packet_publish_at_least_once_works() {
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pkid: Some(PacketIdentifier(10)),
            payload: vec![0xF1, 0xF2, 0xF3, 0xF4],
        });

        assert_eq!(
            serialize(&publish).await,
            vec![0b0011_0010, 11, 0x00, 0x03, b'a', b'/', b'b', 0x00, 0x0a, 0xF1, 0xF2, 0xF3, 0xF4]
        );
    }

    #[tokio::test]
    async fn write_packet_publish_at_most_once_works() {
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pkid: None,
            payload: vec![0xE1, 0xE2, 0xE3, 0xE4],
        });

        assert_eq!(
            serialize(&publish).await,
            vec![0b0011_0000, 9, 0x00, 0x03, b'a', b'/', b'b', 0xE1, 0xE2, 0xE3, 0xE4]
        );
    }

    #[tokio::test]
    async fn write_packet_publish_with_multi_byte_remaining_length_works() {
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: true,
            topic_name: "a/b".to_owned(),
            pkid: None,
            payload: vec![0xAB; 200],
        });

        // remaining length = 2 + 3 (topic) + 200 (payload) = 205 = 77 + 1 * 128
        let mut expected = vec![0b0011_0001, 0xCD, 0x01, 0x00, 0x03, b'a', b'/', b'b'];
        expected.extend_from_slice(&[0xAB; 200]);

        assert_eq!(serialize(&publish).await, expected);
    }

    #[tokio::test]
    async fn write_packet_subscribe_works() {
        let subscribe = Packet::Subscribe(Subscribe {
            pkid: PacketIdentifier(260),
            topics: vec![
                SubscribeTopic {
                    topic_path: "a/+".to_owned(),
                    qos: QoS::AtMostOnce,
                },
                SubscribeTopic {
                    topic_path: "#".to_owned(),
                    qos: QoS::AtLeastOnce,
                },
                SubscribeTopic {
                    topic_path: "a/b/c".to_owned(),
                    qos: QoS::ExactlyOnce,
                },
            ],
        });

        assert_eq!(
            serialize(&subscribe).await,
            vec![
                0b1000_0010, 20, 0x01, 0x04, // pkid = 260
                0x00, 0x03, b'a', b'/', b'+', // topic filter = 'a/+'
                0x00, // qos = 0
                0x00, 0x01, b'#', // topic filter = '#'
                0x01, // qos = 1
                0x00, 0x05, b'a', b'/', b'b', b'/', b'c', // topic filter = 'a/b/c'
                0x02, // qos = 2
            ]
        );
    }

    #[tokio::test]
    async fn write_packet_acks_work() {
        let pkid = PacketIdentifier(0x0102);

        assert_eq!(serialize(&Packet::Puback(pkid)).await, vec![0x40, 0x02, 0x01, 0x02]);
        assert_eq!(serialize(&Packet::Pubrec(pkid)).await, vec![0x50, 0x02, 0x01, 0x02]);
        assert_eq!(serialize(&Packet::Pubrel(pkid)).await, vec![0x62, 0x02, 0x01, 0x02]);
        assert_eq!(serialize(&Packet::Pubcomp(pkid)).await, vec![0x70, 0x02, 0x01, 0x02]);
        assert_eq!(serialize(&Packet::Unsuback(pkid)).await, vec![0xB0, 0x02, 0x01, 0x02]);
    }

    #[tokio::test]
    async fn write_packet_without_payload_works() {
        assert_eq!(serialize(&Packet::Pingreq).await, vec![0xC0, 0x00]);
        assert_eq!(serialize(&Packet::Pingresp).await, vec![0xD0, 0x00]);
        assert_eq!(serialize(&Packet::Disconnect).await, vec![0xE0, 0x00]);
    }

    #[tokio::test]
    async fn write_remaining_length_boundaries() {
        assert_eq!(remaining_length(0).await, vec![0x00]);
        assert_eq!(remaining_length(127).await, vec![0x7F]);
        assert_eq!(remaining_length(128).await, vec![0x80, 0x01]);
        assert_eq!(remaining_length(16_383).await, vec![0xFF, 0x7F]);
        assert_eq!(remaining_length(16_384).await, vec![0x80, 0x80, 0x01]);
        assert_eq!(remaining_length(2_097_151).await, vec![0xFF, 0xFF, 0x7F]);
        assert_eq!(remaining_length(2_097_152).await, vec![0x80, 0x80, 0x80, 0x01]);
        assert_eq!(remaining_length(268_435_455).await, vec![0xFF, 0xFF, 0xFF, 0x7F]);
    }

    #[tokio::test]
    async fn write_remaining_length_rejects_oversized_length() {
        let mut stream: Vec<u8> = Vec::new();
        assert!(stream.write_remaining_length(268_435_456).await.is_err());
        assert!(stream.is_empty());
    }
}