1use crate::mqtt4::*;
2
3use async_trait::async_trait;
4use tokio::io::AsyncWriteExt;
5
6#[async_trait]
8pub trait AsyncMqttWrite: AsyncWriteExt + Unpin {
9 async fn async_mqtt_write(&mut self, packet: &Packet) -> Result<(), Error> {
10 match packet {
11 Packet::Connect(connect) => {
12 self.write_u8(0b00010000).await?;
13 self.write_remaining_length(connect.len()).await?;
14 self.write_mqtt_string("MQTT").await?;
15 self.write_u8(0x04).await?;
16
17 let mut connect_flags = 0;
18 if connect.clean_session {
19 connect_flags |= 0x02;
20 }
21
22 match &connect.last_will {
23 Some(w) if w.retain => connect_flags |= 0x04 | (w.qos as u8) << 3 | 0x20,
24 Some(w) => connect_flags |= 0x04 | (w.qos as u8) << 3,
25 None => ()
26 }
27
28 if let Some(_) = connect.password {
29 connect_flags |= 0x40;
30 }
31 if let Some(_) = connect.username {
32 connect_flags |= 0x80;
33 }
34
35 self.write_u8(connect_flags).await?;
36 self.write_u16(connect.keep_alive).await?;
37 self.write_mqtt_string(connect.client_id.as_ref()).await?;
38
39 if let Some(ref last_will) = connect.last_will {
40 self.write_mqtt_string(last_will.topic.as_ref()).await?;
41 self.write_mqtt_string(last_will.message.as_ref()).await?;
42 }
43 if let Some(ref username) = connect.username {
44 self.write_mqtt_string(username).await?;
45 }
46 if let Some(ref password) = connect.password {
47 self.write_mqtt_string(password).await?;
48 }
49 Ok(())
50 }
51 Packet::Connack(connack) => {
52 let session_present = connack.session_present as u8;
53 let code = connack.code as u8;
54 let data = [0x20, 0x02, session_present, code];
55 self.write_all(&data).await?;
56 Ok(())
57 }
58 Packet::Publish(publish) => {
59 self.write_u8(0b00110000 | publish.retain as u8 | ((publish.qos as u8) << 1) | ((publish.dup as u8) << 3)).await?;
60 let mut len = publish.topic_name.len() + 2 + publish.payload.len();
61
62 if publish.qos != QoS::AtMostOnce && None != publish.pkid {
63 len += 2;
64 }
65
66 self.write_remaining_length(len).await?;
67 self.write_mqtt_string(publish.topic_name.as_str()).await?;
68 if publish.qos != QoS::AtMostOnce {
69 if let Some(pkid) = publish.pkid {
70 self.write_u16(pkid.0).await?;
71 }
72 }
73 self.write_all(&publish.payload.as_ref()).await?;
74 Ok(())
75 }
76 Packet::Puback(pkid) => {
77 self.write_all(&[0x40, 0x02]).await?;
78 self.write_u16(pkid.0).await?;
79 Ok(())
80 }
81 Packet::Pubrec(pkid) => {
82 self.write_all(&[0x50, 0x02]).await?;
83 self.write_u16(pkid.0).await?;
84 Ok(())
85 }
86 Packet::Pubrel(pkid) => {
87 self.write_all(&[0x62, 0x02]).await?;
88 self.write_u16(pkid.0).await?;
89 Ok(())
90 }
91 Packet::Pubcomp(pkid) => {
92 self.write_all(&[0x70, 0x02]).await?;
93 self.write_u16(pkid.0).await?;
94 Ok(())
95 }
96 Packet::Subscribe(subscribe) => {
97 self.write_all(&[0x82]).await?;
98 let len = 2 + subscribe.topics.iter().fold(0, |s, ref t| s + t.topic_path.len() + 3);
99
100 self.write_remaining_length(len).await?;
101 self.write_u16(subscribe.pkid.0).await?;
102 for topic in subscribe.topics.as_ref() as &Vec<SubscribeTopic> {
103 self.write_mqtt_string(topic.topic_path.as_str()).await?;
104 self.write_u8(topic.qos as u8).await?;
105 }
106 Ok(())
107 }
108 Packet::Suback(suback) => {
109 self.write_all(&[0x90]).await?;
110 self.write_remaining_length(suback.return_codes.len() + 2).await?;
111 self.write_u16(suback.pkid.0).await?;
112
113 let payload: Vec<u8> = suback.return_codes.iter().map(|&code| match code {
114 SubscribeReturnCodes::Success(qos) => qos as u8,
115 SubscribeReturnCodes::Failure => 0x80,
116 }).collect();
117
118 self.write_all(&payload).await?;
119 Ok(())
120 }
121 Packet::Unsubscribe(unsubscribe) => {
122 self.write_all(&[0xA2]).await?;
123 let len = 2 + unsubscribe.topics.iter().fold(0, |s, ref topic| s + topic.len() + 2);
124 self.write_remaining_length(len).await?;
125 self.write_u16(unsubscribe.pkid.0).await?;
126
127 for topic in unsubscribe.topics.as_ref() as &Vec<String> {
128 self.write_mqtt_string(topic.as_str()).await?;
129 }
130 Ok(())
131 }
132 Packet::Unsuback(pkid) => {
133 self.write_all(&[0xB0, 0x02]).await?;
134 self.write_u16(pkid.0).await?;
135 Ok(())
136 }
137 Packet::Pingreq => {
138 self.write_all(&[0xc0, 0]).await?;
139 Ok(())
140 }
141 Packet::Pingresp => {
142 self.write_all(&[0xd0, 0]).await?;
143 Ok(())
144 }
145 Packet::Disconnect => {
146 self.write_all(&[0xe0, 0]).await?;
147 Ok(())
148 }
149 }
150 }
151
152 async fn write_mqtt_string(&mut self, string: &str) -> Result<(), Error> {
153 self.write_u16(string.len() as u16).await?;
154 self.write_all(string.as_bytes()).await?;
155 Ok(())
156 }
157
158 async fn write_remaining_length(&mut self, len: usize) -> Result<(), Error> {
159 if len > 268_435_455 {
160 return Err(Error::PayloadTooLong);
161 }
162
163 let mut done = false;
164 let mut x = len;
165
166 while !done {
167 let mut byte = (x % 128) as u8;
168 x = x / 128;
169 if x > 0 {
170 byte = byte | 128;
171 }
172 self.write_u8(byte).await?;
173 done = x <= 0;
174 }
175 Ok(())
176 }
177}
178
179impl<W: AsyncWriteExt + ?Sized + Unpin> AsyncMqttWrite for W {}
181
#[cfg(test)]
mod test {
    // Byte-level encoding tests: each case compares the writer's output against a
    // hand-assembled frame.
    use super::AsyncMqttWrite;
    use super::{Connack, Connect, Packet, Publish, Subscribe};
    use super::{ConnectReturnCode, LastWill, PacketIdentifier, Protocol, QoS, SubscribeTopic};

    #[tokio::test]
    async fn write_packet_connect_mqtt_protocol_works() {
        let connect = Packet::Connect(Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill {
                topic: "/a".to_owned(),
                message: "offline".to_owned(),
                retain: false,
                qos: QoS::AtLeastOnce,
            }),
            username: Some("rust".to_owned()),
            password: Some("mq".to_owned()),
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&connect).await.unwrap();

        assert_eq!(
            stream,
            vec![
                0x10, 39, // fixed header, remaining length
                0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04, // protocol name + level
                0b1100_1110, // connect flags
                0x00, 0x0a, // keep alive
                0x00, 0x04, b't', b'e', b's', b't', // client id
                0x00, 0x02, b'/', b'a', // will topic
                0x00, 0x07, b'o', b'f', b'f', b'l', b'i', b'n', b'e', // will message
                0x00, 0x04, b'r', b'u', b's', b't', // username
                0x00, 0x02, b'm', b'q', // password
            ]
        );
    }

    #[tokio::test]
    async fn write_packet_connack_works() {
        let connack = Packet::Connack(Connack {
            session_present: true,
            code: ConnectReturnCode::Accepted,
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&connack).await.unwrap();

        assert_eq!(stream, vec![0b0010_0000, 0x02, 0x01, 0x00]);
    }

    #[tokio::test]
    async fn write_packet_publish_at_least_once_works() {
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pkid: Some(PacketIdentifier(10)),
            payload: vec![0xF1, 0xF2, 0xF3, 0xF4],
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&publish).await.unwrap();

        assert_eq!(
            stream,
            vec![0b0011_0010, 11, 0x00, 0x03, b'a', b'/', b'b', 0x00, 0x0a, 0xF1, 0xF2, 0xF3, 0xF4]
        );
    }

    #[tokio::test]
    async fn write_packet_publish_at_most_once_works() {
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pkid: None,
            payload: vec![0xE1, 0xE2, 0xE3, 0xE4],
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&publish).await.unwrap();

        assert_eq!(
            stream,
            vec![0b0011_0000, 9, 0x00, 0x03, b'a', b'/', b'b', 0xE1, 0xE2, 0xE3, 0xE4]
        );
    }

    #[tokio::test]
    async fn write_packet_publish_multi_byte_remaining_length_works() {
        // 3 (topic) + 2 (topic length prefix) + 200 (payload) = 205,
        // encoded as 0xCD 0x01 (77 | continuation bit, then 1).
        let payload = vec![0xAB; 200];
        let publish = Packet::Publish(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pkid: None,
            payload: payload.clone(),
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&publish).await.unwrap();

        let mut expected = vec![0b0011_0000, 0xCD, 0x01, 0x00, 0x03, b'a', b'/', b'b'];
        expected.extend_from_slice(&payload);
        assert_eq!(stream, expected);
    }

    #[tokio::test]
    async fn write_packet_subscribe_works() {
        let subscribe = Packet::Subscribe(Subscribe {
            pkid: PacketIdentifier(260),
            topics: vec![
                SubscribeTopic {
                    topic_path: "a/+".to_owned(),
                    qos: QoS::AtMostOnce,
                },
                SubscribeTopic {
                    topic_path: "#".to_owned(),
                    qos: QoS::AtLeastOnce,
                },
                SubscribeTopic {
                    topic_path: "a/b/c".to_owned(),
                    qos: QoS::ExactlyOnce,
                },
            ],
        });

        let mut stream = Vec::new();
        stream.async_mqtt_write(&subscribe).await.unwrap();

        assert_eq!(
            stream,
            vec![
                0b1000_0010, 20, 0x01, 0x04, // fixed header, length, pkid 260
                0x00, 0x03, b'a', b'/', b'+', 0x00, // "a/+" QoS 0
                0x00, 0x01, b'#', 0x01, // "#" QoS 1
                0x00, 0x05, b'a', b'/', b'b', b'/', b'c', 0x02, // "a/b/c" QoS 2
            ]
        );
    }

    #[tokio::test]
    async fn write_packet_identifier_only_packets_work() {
        let cases = vec![
            (Packet::Puback(PacketIdentifier(10)), 0x40_u8),
            (Packet::Pubrec(PacketIdentifier(10)), 0x50),
            (Packet::Pubrel(PacketIdentifier(10)), 0x62),
            (Packet::Pubcomp(PacketIdentifier(10)), 0x70),
            (Packet::Unsuback(PacketIdentifier(10)), 0xB0),
        ];

        for (packet, header) in cases {
            let mut stream = Vec::new();
            stream.async_mqtt_write(&packet).await.unwrap();
            assert_eq!(stream, vec![header, 0x02, 0x00, 0x0a]);
        }
    }

    #[tokio::test]
    async fn write_packet_header_only_packets_work() {
        let cases = vec![
            (Packet::Pingreq, 0xC0_u8),
            (Packet::Pingresp, 0xD0),
            (Packet::Disconnect, 0xE0),
        ];

        for (packet, header) in cases {
            let mut stream = Vec::new();
            stream.async_mqtt_write(&packet).await.unwrap();
            assert_eq!(stream, vec![header, 0x00]);
        }
    }
}