1use std::io::{self, Read, Write};
4
5use crate::control::variable_header::protocol_level::SPEC_3_1_1;
6use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName, VariableHeaderError};
7use crate::control::{ControlType, FixedHeader, PacketType};
8use crate::encodable::VarBytes;
9use crate::packet::{DecodablePacket, PacketError};
10use crate::topic_name::{TopicName, TopicNameDecodeError, TopicNameError};
11use crate::{Decodable, Encodable};
12
13#[derive(Debug, Eq, PartialEq, Clone)]
15pub struct ConnectPacket {
16 fixed_header: FixedHeader,
17 protocol_name: ProtocolName,
18
19 protocol_level: ProtocolLevel,
20 flags: ConnectFlags,
21 keep_alive: KeepAlive,
22
23 payload: ConnectPacketPayload,
24}
25
26encodable_packet!(ConnectPacket(protocol_name, protocol_level, flags, keep_alive, payload));
27
28impl ConnectPacket {
29 pub fn new<C>(client_identifier: C) -> ConnectPacket
30 where
31 C: Into<String>,
32 {
33 ConnectPacket::with_level("MQTT", client_identifier, SPEC_3_1_1).expect("SPEC_3_1_1 should always be valid")
34 }
35
36 pub fn with_level<P, C>(protoname: P, client_identifier: C, level: u8) -> Result<ConnectPacket, VariableHeaderError>
37 where
38 P: Into<String>,
39 C: Into<String>,
40 {
41 let protocol_level = ProtocolLevel::from_u8(level).ok_or(VariableHeaderError::InvalidProtocolVersion)?;
42 let mut pk = ConnectPacket {
43 fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Connect), 0),
44 protocol_name: ProtocolName(protoname.into()),
45 protocol_level,
46 flags: ConnectFlags::empty(),
47 keep_alive: KeepAlive(0),
48 payload: ConnectPacketPayload::new(client_identifier.into()),
49 };
50
51 pk.fix_header_remaining_len();
52
53 Ok(pk)
54 }
55
56 pub fn set_keep_alive(&mut self, keep_alive: u16) {
57 self.keep_alive = KeepAlive(keep_alive);
58 }
59
60 pub fn set_user_name(&mut self, name: Option<String>) {
61 self.flags.user_name = name.is_some();
62 self.payload.user_name = name;
63 self.fix_header_remaining_len();
64 }
65
66 pub fn set_will(&mut self, topic_message: Option<(TopicName, Vec<u8>)>) {
67 self.flags.will_flag = topic_message.is_some();
68
69 self.payload.will = topic_message.map(|(t, m)| (t, VarBytes(m)));
70
71 self.fix_header_remaining_len();
72 }
73
74 pub fn set_password(&mut self, password: Option<String>) {
75 self.flags.password = password.is_some();
76 self.payload.password = password;
77 self.fix_header_remaining_len();
78 }
79
80 pub fn set_client_identifier<I: Into<String>>(&mut self, id: I) {
81 self.payload.client_identifier = id.into();
82 self.fix_header_remaining_len();
83 }
84
85 pub fn set_will_retain(&mut self, will_retain: bool) {
86 self.flags.will_retain = will_retain;
87 }
88
89 pub fn set_will_qos(&mut self, will_qos: u8) {
90 assert!(will_qos <= 2);
91 self.flags.will_qos = will_qos;
92 }
93
94 pub fn set_clean_session(&mut self, clean_session: bool) {
95 self.flags.clean_session = clean_session;
96 }
97
98 pub fn user_name(&self) -> Option<&str> {
99 self.payload.user_name.as_ref().map(|x| &x[..])
100 }
101
102 pub fn password(&self) -> Option<&str> {
103 self.payload.password.as_ref().map(|x| &x[..])
104 }
105
106 pub fn will(&self) -> Option<(&str, &[u8])> {
107 self.payload.will.as_ref().map(|(topic, msg)| (&topic[..], &*msg.0))
108 }
109
110 pub fn will_retain(&self) -> bool {
111 self.flags.will_retain
112 }
113
114 pub fn will_qos(&self) -> u8 {
115 self.flags.will_qos
116 }
117
118 pub fn client_identifier(&self) -> &str {
119 &self.payload.client_identifier[..]
120 }
121
122 pub fn protocol_name(&self) -> &str {
123 &self.protocol_name.0
124 }
125
126 pub fn protocol_level(&self) -> ProtocolLevel {
127 self.protocol_level
128 }
129
130 pub fn clean_session(&self) -> bool {
131 self.flags.clean_session
132 }
133
134 pub fn keep_alive(&self) -> u16 {
135 self.keep_alive.0
136 }
137
138 pub fn reserved_flag(&self) -> bool {
141 self.flags.reserved
142 }
143}
144
145impl DecodablePacket for ConnectPacket {
146 type DecodePacketError = ConnectPacketError;
147
148 fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>> {
149 let protoname: ProtocolName = Decodable::decode(reader)?;
150 let protocol_level: ProtocolLevel = Decodable::decode(reader)?;
151 let flags: ConnectFlags = Decodable::decode(reader)?;
152 let keep_alive: KeepAlive = Decodable::decode(reader)?;
153 let payload: ConnectPacketPayload =
154 Decodable::decode_with(reader, Some(flags)).map_err(PacketError::PayloadError)?;
155
156 Ok(ConnectPacket {
157 fixed_header,
158 protocol_name: protoname,
159 protocol_level,
160 flags,
161 keep_alive,
162 payload,
163 })
164 }
165}
166
167#[derive(Debug, Eq, PartialEq, Clone)]
169struct ConnectPacketPayload {
170 client_identifier: String,
171 will: Option<(TopicName, VarBytes)>,
172 user_name: Option<String>,
173 password: Option<String>,
174}
175
176impl ConnectPacketPayload {
177 pub fn new(client_identifier: String) -> ConnectPacketPayload {
178 ConnectPacketPayload {
179 client_identifier,
180 will: None,
181 user_name: None,
182 password: None,
183 }
184 }
185}
186
187impl Encodable for ConnectPacketPayload {
188 fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
189 self.client_identifier.encode(writer)?;
190
191 if let Some((will_topic, will_message)) = &self.will {
192 will_topic.encode(writer)?;
193 will_message.encode(writer)?;
194 }
195
196 if let Some(ref user_name) = self.user_name {
197 user_name.encode(writer)?;
198 }
199
200 if let Some(ref password) = self.password {
201 password.encode(writer)?;
202 }
203
204 Ok(())
205 }
206
207 fn encoded_length(&self) -> u32 {
208 self.client_identifier.encoded_length()
209 + self
210 .will
211 .as_ref()
212 .map(|(a, b)| a.encoded_length() + b.encoded_length())
213 .unwrap_or(0)
214 + self.user_name.as_ref().map(|t| t.encoded_length()).unwrap_or(0)
215 + self.password.as_ref().map(|t| t.encoded_length()).unwrap_or(0)
216 }
217}
218
219impl Decodable for ConnectPacketPayload {
220 type Error = ConnectPacketError;
221 type Cond = Option<ConnectFlags>;
222
223 fn decode_with<R: Read>(
224 reader: &mut R,
225 rest: Option<ConnectFlags>,
226 ) -> Result<ConnectPacketPayload, ConnectPacketError> {
227 let mut need_will = false;
228 let mut need_user_name = false;
229 let mut need_password = false;
230
231 if let Some(r) = rest {
232 need_will = r.will_flag;
233 need_user_name = r.user_name;
234 need_password = r.password;
235 }
236
237 let ident = String::decode(reader)?;
238 let will = if need_will {
239 let topic = TopicName::decode(reader).map_err(|e| match e {
240 TopicNameDecodeError::IoError(e) => ConnectPacketError::from(e),
241 TopicNameDecodeError::InvalidTopicName(e) => e.into(),
242 })?;
243 let msg = VarBytes::decode(reader)?;
244 Some((topic, msg))
245 } else {
246 None
247 };
248 let uname = if need_user_name {
249 Some(String::decode(reader)?)
250 } else {
251 None
252 };
253 let pwd = if need_password {
254 Some(String::decode(reader)?)
255 } else {
256 None
257 };
258
259 Ok(ConnectPacketPayload {
260 client_identifier: ident,
261 will,
262 user_name: uname,
263 password: pwd,
264 })
265 }
266}
267
268#[derive(Debug, thiserror::Error)]
269#[error(transparent)]
270pub enum ConnectPacketError {
271 IoError(#[from] io::Error),
272 TopicNameError(#[from] TopicNameError),
273}
274
#[cfg(test)]
mod test {
    use super::*;

    use std::io::Cursor;

    use crate::{Decodable, Encodable};

    /// Encodes `packet` and decodes the resulting bytes back into a `ConnectPacket`.
    fn round_trip(packet: &ConnectPacket) -> ConnectPacket {
        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();

        let mut decode_buf = Cursor::new(buf);
        ConnectPacket::decode(&mut decode_buf).unwrap()
    }

    #[test]
    fn test_connect_packet_encode_basic() {
        let packet = ConnectPacket::new("12345".to_owned());
        let expected = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345";

        let mut buf = Vec::new();
        packet.encode(&mut buf).unwrap();

        assert_eq!(&expected[..], &buf[..]);
    }

    #[test]
    fn test_connect_packet_decode_basic() {
        let encoded_data = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345";

        let mut buf = Cursor::new(&encoded_data[..]);
        let packet = ConnectPacket::decode(&mut buf).unwrap();

        let expected = ConnectPacket::new("12345".to_owned());
        assert_eq!(expected, packet);
    }

    #[test]
    fn test_connect_packet_user_name() {
        let mut packet = ConnectPacket::new("12345".to_owned());
        packet.set_user_name(Some("mqtt_player".to_owned()));

        let decoded_packet = round_trip(&packet);

        assert_eq!(packet, decoded_packet);
    }

    // Password, keep-alive and clean-session were not covered by any round trip.
    #[test]
    fn test_connect_packet_credentials_and_session_flags() {
        let mut packet = ConnectPacket::new("12345".to_owned());
        packet.set_user_name(Some("mqtt_player".to_owned()));
        packet.set_password(Some("secret".to_owned()));
        packet.set_keep_alive(60);
        packet.set_clean_session(true);

        let decoded_packet = round_trip(&packet);

        assert_eq!(packet, decoded_packet);
        assert_eq!(decoded_packet.user_name(), Some("mqtt_player"));
        assert_eq!(decoded_packet.password(), Some("secret"));
        assert_eq!(decoded_packet.keep_alive(), 60);
        assert!(decoded_packet.clean_session());
    }

    // Clearing an optional field must also clear its flag and restore the length.
    #[test]
    fn test_connect_packet_clearing_user_name() {
        let mut packet = ConnectPacket::new("12345".to_owned());
        packet.set_user_name(Some("mqtt_player".to_owned()));
        packet.set_user_name(None);

        assert_eq!(packet.user_name(), None);
        assert_eq!(packet, ConnectPacket::new("12345".to_owned()));
    }
}
318}