mqtt/packet/
connect.rs

1//! CONNECT
2
3use std::io::{self, Read, Write};
4
5use crate::control::variable_header::protocol_level::SPEC_3_1_1;
6use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName, VariableHeaderError};
7use crate::control::{ControlType, FixedHeader, PacketType};
8use crate::encodable::VarBytes;
9use crate::packet::{DecodablePacket, PacketError};
10use crate::topic_name::{TopicName, TopicNameDecodeError, TopicNameError};
11use crate::{Decodable, Encodable};
12
13/// `CONNECT` packet
14#[derive(Debug, Eq, PartialEq, Clone)]
15pub struct ConnectPacket {
16    fixed_header: FixedHeader,
17    protocol_name: ProtocolName,
18
19    protocol_level: ProtocolLevel,
20    flags: ConnectFlags,
21    keep_alive: KeepAlive,
22
23    payload: ConnectPacketPayload,
24}
25
26encodable_packet!(ConnectPacket(protocol_name, protocol_level, flags, keep_alive, payload));
27
28impl ConnectPacket {
29    pub fn new<C>(client_identifier: C) -> ConnectPacket
30    where
31        C: Into<String>,
32    {
33        ConnectPacket::with_level("MQTT", client_identifier, SPEC_3_1_1).expect("SPEC_3_1_1 should always be valid")
34    }
35
36    pub fn with_level<P, C>(protoname: P, client_identifier: C, level: u8) -> Result<ConnectPacket, VariableHeaderError>
37    where
38        P: Into<String>,
39        C: Into<String>,
40    {
41        let protocol_level = ProtocolLevel::from_u8(level).ok_or(VariableHeaderError::InvalidProtocolVersion)?;
42        let mut pk = ConnectPacket {
43            fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Connect), 0),
44            protocol_name: ProtocolName(protoname.into()),
45            protocol_level,
46            flags: ConnectFlags::empty(),
47            keep_alive: KeepAlive(0),
48            payload: ConnectPacketPayload::new(client_identifier.into()),
49        };
50
51        pk.fix_header_remaining_len();
52
53        Ok(pk)
54    }
55
56    pub fn set_keep_alive(&mut self, keep_alive: u16) {
57        self.keep_alive = KeepAlive(keep_alive);
58    }
59
60    pub fn set_user_name(&mut self, name: Option<String>) {
61        self.flags.user_name = name.is_some();
62        self.payload.user_name = name;
63        self.fix_header_remaining_len();
64    }
65
66    pub fn set_will(&mut self, topic_message: Option<(TopicName, Vec<u8>)>) {
67        self.flags.will_flag = topic_message.is_some();
68
69        self.payload.will = topic_message.map(|(t, m)| (t, VarBytes(m)));
70
71        self.fix_header_remaining_len();
72    }
73
74    pub fn set_password(&mut self, password: Option<String>) {
75        self.flags.password = password.is_some();
76        self.payload.password = password;
77        self.fix_header_remaining_len();
78    }
79
80    pub fn set_client_identifier<I: Into<String>>(&mut self, id: I) {
81        self.payload.client_identifier = id.into();
82        self.fix_header_remaining_len();
83    }
84
85    pub fn set_will_retain(&mut self, will_retain: bool) {
86        self.flags.will_retain = will_retain;
87    }
88
89    pub fn set_will_qos(&mut self, will_qos: u8) {
90        assert!(will_qos <= 2);
91        self.flags.will_qos = will_qos;
92    }
93
94    pub fn set_clean_session(&mut self, clean_session: bool) {
95        self.flags.clean_session = clean_session;
96    }
97
98    pub fn user_name(&self) -> Option<&str> {
99        self.payload.user_name.as_ref().map(|x| &x[..])
100    }
101
102    pub fn password(&self) -> Option<&str> {
103        self.payload.password.as_ref().map(|x| &x[..])
104    }
105
106    pub fn will(&self) -> Option<(&str, &[u8])> {
107        self.payload.will.as_ref().map(|(topic, msg)| (&topic[..], &*msg.0))
108    }
109
110    pub fn will_retain(&self) -> bool {
111        self.flags.will_retain
112    }
113
114    pub fn will_qos(&self) -> u8 {
115        self.flags.will_qos
116    }
117
118    pub fn client_identifier(&self) -> &str {
119        &self.payload.client_identifier[..]
120    }
121
122    pub fn protocol_name(&self) -> &str {
123        &self.protocol_name.0
124    }
125
126    pub fn protocol_level(&self) -> ProtocolLevel {
127        self.protocol_level
128    }
129
130    pub fn clean_session(&self) -> bool {
131        self.flags.clean_session
132    }
133
134    pub fn keep_alive(&self) -> u16 {
135        self.keep_alive.0
136    }
137
138    /// Read back the "reserved" Connect flag bit 0. For compliant implementations this should
139    /// always be false.
140    pub fn reserved_flag(&self) -> bool {
141        self.flags.reserved
142    }
143}
144
145impl DecodablePacket for ConnectPacket {
146    type DecodePacketError = ConnectPacketError;
147
148    fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>> {
149        let protoname: ProtocolName = Decodable::decode(reader)?;
150        let protocol_level: ProtocolLevel = Decodable::decode(reader)?;
151        let flags: ConnectFlags = Decodable::decode(reader)?;
152        let keep_alive: KeepAlive = Decodable::decode(reader)?;
153        let payload: ConnectPacketPayload =
154            Decodable::decode_with(reader, Some(flags)).map_err(PacketError::PayloadError)?;
155
156        Ok(ConnectPacket {
157            fixed_header,
158            protocol_name: protoname,
159            protocol_level,
160            flags,
161            keep_alive,
162            payload,
163        })
164    }
165}
166
167/// Payloads for connect packet
168#[derive(Debug, Eq, PartialEq, Clone)]
169struct ConnectPacketPayload {
170    client_identifier: String,
171    will: Option<(TopicName, VarBytes)>,
172    user_name: Option<String>,
173    password: Option<String>,
174}
175
176impl ConnectPacketPayload {
177    pub fn new(client_identifier: String) -> ConnectPacketPayload {
178        ConnectPacketPayload {
179            client_identifier,
180            will: None,
181            user_name: None,
182            password: None,
183        }
184    }
185}
186
187impl Encodable for ConnectPacketPayload {
188    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
189        self.client_identifier.encode(writer)?;
190
191        if let Some((will_topic, will_message)) = &self.will {
192            will_topic.encode(writer)?;
193            will_message.encode(writer)?;
194        }
195
196        if let Some(ref user_name) = self.user_name {
197            user_name.encode(writer)?;
198        }
199
200        if let Some(ref password) = self.password {
201            password.encode(writer)?;
202        }
203
204        Ok(())
205    }
206
207    fn encoded_length(&self) -> u32 {
208        self.client_identifier.encoded_length()
209            + self
210                .will
211                .as_ref()
212                .map(|(a, b)| a.encoded_length() + b.encoded_length())
213                .unwrap_or(0)
214            + self.user_name.as_ref().map(|t| t.encoded_length()).unwrap_or(0)
215            + self.password.as_ref().map(|t| t.encoded_length()).unwrap_or(0)
216    }
217}
218
219impl Decodable for ConnectPacketPayload {
220    type Error = ConnectPacketError;
221    type Cond = Option<ConnectFlags>;
222
223    fn decode_with<R: Read>(
224        reader: &mut R,
225        rest: Option<ConnectFlags>,
226    ) -> Result<ConnectPacketPayload, ConnectPacketError> {
227        let mut need_will = false;
228        let mut need_user_name = false;
229        let mut need_password = false;
230
231        if let Some(r) = rest {
232            need_will = r.will_flag;
233            need_user_name = r.user_name;
234            need_password = r.password;
235        }
236
237        let ident = String::decode(reader)?;
238        let will = if need_will {
239            let topic = TopicName::decode(reader).map_err(|e| match e {
240                TopicNameDecodeError::IoError(e) => ConnectPacketError::from(e),
241                TopicNameDecodeError::InvalidTopicName(e) => e.into(),
242            })?;
243            let msg = VarBytes::decode(reader)?;
244            Some((topic, msg))
245        } else {
246            None
247        };
248        let uname = if need_user_name {
249            Some(String::decode(reader)?)
250        } else {
251            None
252        };
253        let pwd = if need_password {
254            Some(String::decode(reader)?)
255        } else {
256            None
257        };
258
259        Ok(ConnectPacketPayload {
260            client_identifier: ident,
261            will,
262            user_name: uname,
263            password: pwd,
264        })
265    }
266}
267
268#[derive(Debug, thiserror::Error)]
269#[error(transparent)]
270pub enum ConnectPacketError {
271    IoError(#[from] io::Error),
272    TopicNameError(#[from] TopicNameError),
273}
274
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::{Decodable, Encodable};

    /// Encodes `packet` and decodes the bytes back into a new packet.
    fn round_trip(packet: &ConnectPacket) -> ConnectPacket {
        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();

        let mut decode_buf = Cursor::new(buf);
        ConnectPacket::decode(&mut decode_buf).unwrap()
    }

    #[test]
    fn test_connect_packet_encode_basic() {
        let packet = ConnectPacket::new("12345".to_owned());
        let expected = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345";

        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();

        assert_eq!(&expected[..], &buf[..]);
    }

    #[test]
    fn test_connect_packet_decode_basic() {
        let encoded_data = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345";

        let mut buf = Cursor::new(&encoded_data[..]);
        let packet = ConnectPacket::decode(&mut buf).unwrap();

        let expected = ConnectPacket::new("12345".to_owned());
        assert_eq!(expected, packet);
    }

    #[test]
    fn test_connect_packet_user_name() {
        let mut packet = ConnectPacket::new("12345".to_owned());
        packet.set_user_name(Some("mqtt_player".to_owned()));

        let decoded_packet = round_trip(&packet);

        assert_eq!(packet, decoded_packet);
    }

    #[test]
    fn test_connect_packet_credentials_and_session_round_trip() {
        let mut packet = ConnectPacket::new("12345");
        packet.set_user_name(Some("mqtt_player".to_owned()));
        packet.set_password(Some("secret".to_owned()));
        packet.set_keep_alive(60);
        packet.set_clean_session(true);

        let decoded = round_trip(&packet);

        assert_eq!(packet, decoded);
        assert_eq!(decoded.client_identifier(), "12345");
        assert_eq!(decoded.user_name(), Some("mqtt_player"));
        assert_eq!(decoded.password(), Some("secret"));
        assert_eq!(decoded.keep_alive(), 60);
        assert!(decoded.clean_session());
    }

    #[test]
    fn test_connect_packet_will_round_trip() {
        // Build the topic through the same decoder the payload uses (u16 length + "a/b").
        let topic = match TopicName::decode(&mut Cursor::new(&b"\x00\x03a/b"[..])) {
            Ok(topic) => topic,
            Err(_) => panic!("\"a/b\" should be a valid topic name"),
        };

        let mut packet = ConnectPacket::new("12345");
        packet.set_will(Some((topic, b"bye".to_vec())));
        packet.set_will_qos(1);
        packet.set_will_retain(true);

        let decoded = round_trip(&packet);

        assert_eq!(packet, decoded);
        assert_eq!(decoded.will(), Some(("a/b", &b"bye"[..])));
        assert_eq!(decoded.will_qos(), 1);
        assert!(decoded.will_retain());
    }
}