RustMqtt/
write.rs

1use crate::{Error, Packet, QoS, Result, SubscribeReturnCodes, SubscribeTopic, MAX_PAYLOAD_SIZE};
2use byteorder::{BigEndian, WriteBytesExt};
3use std::io::{BufWriter, Cursor, Write};
4use std::net::TcpStream;
5
6pub trait MqttWrite: WriteBytesExt {
7    fn write_packet(&mut self, packet: &Packet) -> Result<()> {
8        match packet {
9            &Packet::Connect(ref connect) => {
10                self.write_u8(0b00010000)?;
11                let prot_name = connect.protocol.name();
12                let mut len = 8 + prot_name.len() + connect.client_id.len();
13                if let Some(ref last_will) = connect.last_will {
14                    len += 4 + last_will.topic.len() + last_will.message.len();
15                }
16                if let Some(ref username) = connect.username {
17                    len += 2 + username.len();
18                }
19                if let Some(ref password) = connect.password {
20                    len += 2 + password.len();
21                }
22                self.write_remaining_length(len)?;
23                self.write_mqtt_string(prot_name)?;
24                self.write_u8(connect.protocol.level())?;
25                let mut connect_flags = 0;
26                if connect.clean_session {
27                    connect_flags |= 0x02;
28                }
29                if let Some(ref last_will) = connect.last_will {
30                    connect_flags |= 0x04;
31                    connect_flags |= last_will.qos.to_u8() << 3;
32                    if last_will.retain {
33                        connect_flags |= 0x20;
34                    }
35                }
36                if let Some(_) = connect.password {
37                    connect_flags |= 0x40;
38                }
39                if let Some(_) = connect.username {
40                    connect_flags |= 0x80;
41                }
42                self.write_u8(connect_flags)?;
43                self.write_u16::<BigEndian>(connect.keep_alive)?;
44                self.write_mqtt_string(connect.client_id.as_ref())?;
45                if let Some(ref last_will) = connect.last_will {
46                    self.write_mqtt_string(last_will.topic.as_ref())?;
47                    self.write_mqtt_string(last_will.message.as_ref())?;
48                }
49                if let Some(ref username) = connect.username {
50                    self.write_mqtt_string(username)?;
51                }
52                if let Some(ref password) = connect.password {
53                    self.write_mqtt_string(password)?;
54                }
55                Ok(())
56            }
57            &Packet::Connack(ref connack) => {
58                self.write(&[
59                    0x20,
60                    0x02,
61                    connack.session_present as u8,
62                    connack.code.to_u8(),
63                ])?;
64                Ok(())
65            }
66            &Packet::Publish(ref publish) => {
67                self.write_u8(
68                    0b00110000
69                        | publish.retain as u8
70                        | (publish.qos.to_u8() << 1)
71                        | ((publish.dup as u8) << 3),
72                )?;
73                let mut len = publish.topic_name.len() + 2 + publish.payload.len();
74                if publish.qos != QoS::AtMostOnce && None != publish.pid {
75                    len += 2;
76                }
77                self.write_remaining_length(len)?;
78                self.write_mqtt_string(publish.topic_name.as_str())?;
79                if publish.qos != QoS::AtMostOnce {
80                    if let Some(pid) = publish.pid {
81                        self.write_u16::<BigEndian>(pid.0)?;
82                    }
83                }
84                self.write(&publish.payload.as_ref())?;
85                Ok(())
86            }
87            &Packet::Puback(ref pid) => {
88                self.write(&[0x40, 0x02])?;
89                self.write_u16::<BigEndian>(pid.0)?;
90                Ok(())
91            }
92            &Packet::Pubrec(ref pid) => {
93                self.write(&[0x50, 0x02])?;
94                self.write_u16::<BigEndian>(pid.0)?;
95                Ok(())
96            }
97            &Packet::Pubrel(ref pid) => {
98                self.write(&[0x62, 0x02])?;
99                self.write_u16::<BigEndian>(pid.0)?;
100                Ok(())
101            }
102            &Packet::Pubcomp(ref pid) => {
103                self.write(&[0x70, 0x02])?;
104                self.write_u16::<BigEndian>(pid.0)?;
105                Ok(())
106            }
107            &Packet::Subscribe(ref subscribe) => {
108                self.write(&[0x82])?;
109                let len = 2 + subscribe
110                    .topics
111                    .iter()
112                    .fold(0, |s, ref t| s + t.topic_path.len() + 3);
113                self.write_remaining_length(len)?;
114                self.write_u16::<BigEndian>(subscribe.pid.0)?;
115                for topic in subscribe.topics.as_ref() as &Vec<SubscribeTopic> {
116                    self.write_mqtt_string(topic.topic_path.as_str())?;
117                    self.write_u8(topic.qos.to_u8())?;
118                }
119                Ok(())
120            }
121            &Packet::Suback(ref suback) => {
122                self.write(&[0x90])?;
123                self.write_remaining_length(suback.return_codes.len() + 2)?;
124                self.write_u16::<BigEndian>(suback.pid.0)?;
125                let payload: Vec<u8> = suback
126                    .return_codes
127                    .iter()
128                    .map({
129                        |&code| match code {
130                            SubscribeReturnCodes::Success(qos) => qos.to_u8(),
131                            SubscribeReturnCodes::Failure => 0x80,
132                        }
133                    })
134                    .collect();
135                self.write(&payload)?;
136                Ok(())
137            }
138            &Packet::Unsubscribe(ref unsubscribe) => {
139                self.write(&[0xA2])?;
140                let len = 2 + unsubscribe
141                    .topics
142                    .iter()
143                    .fold(0, |s, ref topic| s + topic.len() + 2);
144                self.write_remaining_length(len)?;
145                self.write_u16::<BigEndian>(unsubscribe.pid.0)?;
146                for topic in unsubscribe.topics.as_ref() as &Vec<String> {
147                    self.write_mqtt_string(topic.as_str())?;
148                }
149                Ok(())
150            }
151            &Packet::Unsuback(ref pid) => {
152                self.write(&[0xB0, 0x02])?;
153                self.write_u16::<BigEndian>(pid.0)?;
154                Ok(())
155            }
156            &Packet::Pingreq => {
157                self.write(&[0xc0, 0])?;
158                Ok(())
159            }
160            &Packet::Pingresp => {
161                self.write(&[0xd0, 0])?;
162                Ok(())
163            }
164            &Packet::Disconnect => {
165                self.write(&[0xe0, 0])?;
166                Ok(())
167            }
168        }
169    }
170
171    fn write_mqtt_string(&mut self, string: &str) -> Result<()> {
172        self.write_u16::<BigEndian>(string.len() as u16)?;
173        self.write(string.as_bytes())?;
174        Ok(())
175    }
176
177    fn write_remaining_length(&mut self, len: usize) -> Result<()> {
178        if len > MAX_PAYLOAD_SIZE {
179            return Err(Error::PayloadTooLong);
180        }
181
182        let mut done = false;
183        let mut x = len;
184
185        while !done {
186            let mut byte = (x % 128) as u8;
187            x = x / 128;
188            if x > 0 {
189                byte = byte | 128;
190            }
191            self.write_u8(byte)?;
192            done = x <= 0;
193        }
194
195        Ok(())
196    }
197}
198
199impl MqttWrite for TcpStream {}
200impl MqttWrite for Cursor<Vec<u8>> {}
201impl<T: Write> MqttWrite for BufWriter<T> {}
202
#[cfg(test)]
mod test {
    use super::super::mqtt::{Connack, Connect, Packet, Publish, Subscribe};
    use super::super::{
        ConnectReturnCode, LastWill, PacketIdentifier, Protocol, QoS, SubscribeTopic,
    };
    use super::MqttWrite;
    use std::io::Cursor;
    use std::sync::Arc;

    /// Encodes `packet` into a fresh in-memory buffer and returns the bytes written.
    fn encode(packet: &Packet) -> Vec<u8> {
        let mut sink = Cursor::new(Vec::new());
        sink.write_packet(packet).unwrap();
        sink.into_inner()
    }

    #[test]
    fn write_packet_connect_mqtt_protocol_test() {
        let will = LastWill {
            topic: "/a".to_owned(),
            message: "offline".to_owned(),
            retain: false,
            qos: QoS::AtLeastOnce,
        };
        let packet = Packet::Connect(Box::new(Connect {
            protocol: Protocol::MQTT(4),
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(will),
            username: Some("rust".to_owned()),
            password: Some("mq".to_owned()),
        }));

        let expected = vec![
            0x10, 39, // CONNECT, remaining length
            0x00, 0x04, b'M', b'Q', b'T', b'T', // protocol name = 'MQTT'
            0x04, // protocol level
            0b1100_1110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session
            0x00, 0x0a, // keep alive = 10 sec
            0x00, 0x04, b't', b'e', b's', b't', // client_id = 'test'
            0x00, 0x02, b'/', b'a', // will topic = '/a'
            0x00, 0x07, b'o', b'f', b'f', b'l', b'i', b'n', b'e', // will msg = 'offline'
            0x00, 0x04, b'r', b'u', b's', b't', // username = 'rust'
            0x00, 0x02, b'm', b'q', // password = 'mq'
        ];
        assert_eq!(encode(&packet), expected);
    }

    #[test]
    fn write_packet_connect_mqisdp_protocol_test() {
        let packet = Packet::Connect(Box::new(Connect {
            protocol: Protocol::MQIsdp(3),
            keep_alive: 60,
            client_id: "test".to_owned(),
            clean_session: false,
            last_will: None,
            username: None,
            password: None,
        }));

        let expected = vec![
            0x10, 18, // CONNECT, remaining length
            0x00, 0x06, b'M', b'Q', b'I', b's', b'd', b'p', // protocol name = 'MQIsdp'
            0x03, // protocol level
            0b0000_0000, // -username, -password, -will retain, will qos=0, -last_will, -clean_session
            0x00, 0x3c, // keep alive = 60 sec
            0x00, 0x04, b't', b'e', b's', b't', // client_id = 'test'
        ];
        assert_eq!(encode(&packet), expected);
    }

    #[test]
    fn write_packet_connack_test() {
        let packet = Packet::Connack(Connack {
            session_present: true,
            code: ConnectReturnCode::Accepted,
        });

        let expected = vec![0b0010_0000, 0x02, 0x01, 0x00];
        assert_eq!(encode(&packet), expected);
    }

    #[test]
    fn write_packet_publish_at_least_once_test() {
        let packet = Packet::Publish(Box::new(Publish {
            dup: false,
            qos: QoS::AtLeastOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pid: Some(PacketIdentifier(10)),
            payload: Arc::new(vec![0xF1, 0xF2, 0xF3, 0xF4]),
        }));

        let expected = vec![
            0b0011_0010, 11, // PUBLISH qos=1, remaining length
            0x00, 0x03, b'a', b'/', b'b', // topic = 'a/b'
            0x00, 0x0a, // pid = 10
            0xF1, 0xF2, 0xF3, 0xF4, // payload
        ];
        assert_eq!(encode(&packet), expected);
    }

    #[test]
    fn write_packet_publish_at_most_once_test() {
        let packet = Packet::Publish(Box::new(Publish {
            dup: false,
            qos: QoS::AtMostOnce,
            retain: false,
            topic_name: "a/b".to_owned(),
            pid: None,
            payload: Arc::new(vec![0xE1, 0xE2, 0xE3, 0xE4]),
        }));

        let expected = vec![
            0b0011_0000, 9, // PUBLISH qos=0, remaining length
            0x00, 0x03, b'a', b'/', b'b', // topic = 'a/b'
            0xE1, 0xE2, 0xE3, 0xE4, // payload
        ];
        assert_eq!(encode(&packet), expected);
    }

    #[test]
    fn write_packet_subscribe_test() {
        let topics = vec![
            SubscribeTopic {
                topic_path: "a/+".to_owned(),
                qos: QoS::AtMostOnce,
            },
            SubscribeTopic {
                topic_path: "#".to_owned(),
                qos: QoS::AtLeastOnce,
            },
            SubscribeTopic {
                topic_path: "a/b/c".to_owned(),
                qos: QoS::ExactlyOnce,
            },
        ];
        let packet = Packet::Subscribe(Box::new(Subscribe {
            pid: PacketIdentifier(260),
            topics,
        }));

        let expected = vec![
            0b1000_0010, 20, // SUBSCRIBE, remaining length
            0x01, 0x04, // pid = 260
            0x00, 0x03, b'a', b'/', b'+', // topic filter = 'a/+'
            0x00, // qos = 0
            0x00, 0x01, b'#', // topic filter = '#'
            0x01, // qos = 1
            0x00, 0x05, b'a', b'/', b'b', b'/', b'c', // topic filter = 'a/b/c'
            0x02, // qos = 2
        ];
        assert_eq!(encode(&packet), expected);
    }
}