1use std::num::{NonZeroU16, NonZeroU32};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use bytestring::ByteString;
5use serde::{Deserialize, Serialize};
6
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{ConnectFlags, QoS, MQTT, MQTT_LEVEL_5, WILL_QOS_SHIFT};
9use crate::utils::{self, Decode, Encode, Property};
10use crate::v5::{encode::*, property_type as pt, UserProperties, UserProperty};
11
/// MQTT v5 CONNECT packet (variable header properties plus payload).
///
/// Decoded by `Connect::decode` and written by the `EncodeLtd` impl for `Connect`.
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct Connect {
    /// Clean Start flag (`ConnectFlags::CLEAN_START` in the connect flags byte).
    pub clean_start: bool,
    /// Keep Alive value, read as a big-endian `u16` right after the connect flags.
    pub keep_alive: u16,

    /// Session Expiry Interval property (`SESS_EXPIRY_INT`); per its name, in seconds.
    /// Decoded as `0` when absent; encoded with default `0`.
    pub session_expiry_interval_secs: u32,
    /// Authentication Method property (`AUTH_METHOD`), if present.
    pub auth_method: Option<ByteString>,
    /// Authentication Data property (`AUTH_DATA`), if present.
    pub auth_data: Option<Bytes>,
    /// Request Problem Information property (`REQ_PROB_INFO`).
    /// Decoded as `true` when absent; encoded with default `true`.
    pub request_problem_info: bool,
    /// Request Response Information property (`REQ_RESP_INFO`).
    /// Decoded as `false` when absent; encoded with default `false`.
    pub request_response_info: bool,
    /// Receive Maximum property (`RECEIVE_MAX`); `None` when absent.
    pub receive_max: Option<NonZeroU16>,
    /// Topic Alias Maximum property (`TOPIC_ALIAS_MAX`).
    /// Decoded as `0` when absent; encoded with default `0`.
    pub topic_alias_max: u16,
    /// User Property entries (`USER`), in the order they appeared; may repeat.
    pub user_properties: UserProperties,
    /// Maximum Packet Size property (`MAX_PACKET_SIZE`); `None` when absent.
    pub max_packet_size: Option<NonZeroU32>,

    /// Will message; present exactly when the `ConnectFlags::WILL` flag is set.
    pub last_will: Option<LastWill>,
    /// Client identifier. The decoder rejects an empty id unless `clean_start` is set.
    pub client_id: ByteString,
    /// User name; present exactly when the `ConnectFlags::USERNAME` flag is set.
    pub username: Option<ByteString>,
    /// Password; present exactly when the `ConnectFlags::PASSWORD` flag is set.
    pub password: Option<Bytes>,
}
39
/// Will message carried in the CONNECT payload when the `ConnectFlags::WILL` flag is set.
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct LastWill {
    /// Will QoS, taken from the `ConnectFlags::WILL_QOS` bits of the connect flags.
    pub qos: QoS,
    /// Will Retain, taken from the `ConnectFlags::WILL_RETAIN` connect flag.
    pub retain: bool,
    /// Will topic, decoded after the Will properties.
    pub topic: ByteString,
    /// Will payload, decoded after the Will topic.
    pub message: Bytes,

    /// Will Delay Interval property (`WILL_DELAY_INT`); per its name, in seconds.
    pub will_delay_interval_sec: Option<u32>,
    /// Correlation Data property (`CORR_DATA`).
    pub correlation_data: Option<Bytes>,
    /// Message Expiry Interval property (`MSG_EXPIRY_INT`); zero is not representable.
    pub message_expiry_interval: Option<NonZeroU32>,
    /// Content Type property (`CONTENT_TYPE`).
    pub content_type: Option<ByteString>,
    /// User Property entries (`USER`) of the Will, in the order they appeared.
    pub user_properties: UserProperties,
    /// Payload Format Indicator property (`UTF8_PAYLOAD`).
    pub is_utf8_payload: Option<bool>,
    /// Response Topic property (`RESP_TOPIC`).
    pub response_topic: Option<ByteString>,
}
60
61impl LastWill {
62 fn properties_len(&self) -> usize {
63 encoded_property_size(&self.will_delay_interval_sec)
64 + encoded_property_size(&self.correlation_data)
65 + encoded_property_size(&self.message_expiry_interval)
66 + encoded_property_size(&self.content_type)
67 + encoded_property_size(&self.is_utf8_payload)
68 + encoded_property_size(&self.response_topic)
69 + self.user_properties.encoded_size()
70 }
71}
72
73impl Connect {
74 pub fn client_id<T>(mut self, client_id: T) -> Self
76 where
77 ByteString: From<T>,
78 {
79 self.client_id = client_id.into();
80 self
81 }
82
83 pub fn receive_max(mut self, max: u16) -> Self {
85 if let Some(num) = NonZeroU16::new(max) {
86 self.receive_max = Some(num);
87 } else {
88 self.receive_max = None;
89 }
90 self
91 }
92
93 fn properties_len(&self) -> usize {
94 encoded_property_size(&self.auth_method)
95 + encoded_property_size(&self.auth_data)
96 + encoded_property_size_default(&self.session_expiry_interval_secs, 0)
97 + encoded_property_size_default(&self.request_problem_info, true) + encoded_property_size_default(&self.request_response_info, false) + encoded_property_size(&self.receive_max)
100 + encoded_property_size(&self.max_packet_size)
101 + encoded_property_size_default(&self.topic_alias_max, 0)
102 + self.user_properties.encoded_size()
103 }
104
105 pub(crate) fn decode(src: &mut Bytes) -> Result<Self, DecodeError> {
106 ensure!(src.remaining() >= 10, DecodeError::InvalidLength);
107 let len = src.get_u16();
108
109 ensure!(len == 4 && &src.as_ref()[0..4] == MQTT, DecodeError::InvalidProtocol);
110 src.advance(4);
111
112 let level = src.get_u8();
113 ensure!(level == MQTT_LEVEL_5, DecodeError::UnsupportedProtocolLevel);
114
115 let flags = ConnectFlags::from_bits(src.get_u8()).ok_or(DecodeError::ConnectReservedFlagSet)?;
116 let keep_alive = src.get_u16();
117
118 let mut session_expiry_interval_secs = None;
120 let mut auth_method = None;
121 let mut auth_data = None;
122 let mut request_problem_info = None;
123 let mut request_response_info = None;
124 let mut receive_max = None;
125 let mut topic_alias_max = None;
126 let mut user_properties = Vec::new();
127 let mut max_packet_size = None;
128 let prop_src = &mut utils::take_properties(src)?;
129 while prop_src.has_remaining() {
130 match prop_src.get_u8() {
131 pt::SESS_EXPIRY_INT => session_expiry_interval_secs.read_value(prop_src)?,
132 pt::AUTH_METHOD => auth_method.read_value(prop_src)?,
133 pt::AUTH_DATA => auth_data.read_value(prop_src)?,
134 pt::REQ_PROB_INFO => request_problem_info.read_value(prop_src)?,
135 pt::REQ_RESP_INFO => request_response_info.read_value(prop_src)?,
136 pt::RECEIVE_MAX => receive_max.read_value(prop_src)?,
137 pt::TOPIC_ALIAS_MAX => topic_alias_max.read_value(prop_src)?,
138 pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
139 pt::MAX_PACKET_SIZE => max_packet_size.read_value(prop_src)?,
140 _ => return Err(DecodeError::MalformedPacket),
141 }
142 }
143
144 let client_id = ByteString::decode(src)?;
145
146 ensure!(
147 !client_id.is_empty() || flags.contains(ConnectFlags::CLEAN_START),
149 DecodeError::InvalidClientId
150 );
151
152 let last_will =
153 if flags.contains(ConnectFlags::WILL) { Some(decode_last_will(src, flags)?) } else { None };
154
155 let username =
156 if flags.contains(ConnectFlags::USERNAME) { Some(ByteString::decode(src)?) } else { None };
157 let password = if flags.contains(ConnectFlags::PASSWORD) { Some(Bytes::decode(src)?) } else { None };
158
159 Ok(Connect {
160 clean_start: flags.contains(ConnectFlags::CLEAN_START),
161 keep_alive,
162 session_expiry_interval_secs: session_expiry_interval_secs.unwrap_or(0),
163 auth_method,
164 auth_data,
165 receive_max,
166 topic_alias_max: topic_alias_max.unwrap_or(0u16),
167 request_problem_info: request_problem_info.unwrap_or(true),
168 request_response_info: request_response_info.unwrap_or(false),
169 user_properties,
170 max_packet_size,
171
172 client_id,
173 last_will,
174 username,
175 password,
176 })
177 }
178}
179
180impl Default for Connect {
181 fn default() -> Connect {
182 Connect {
183 clean_start: false,
184 keep_alive: 0,
185 session_expiry_interval_secs: 0,
186 auth_method: None,
187 auth_data: None,
188 request_problem_info: true,
189 request_response_info: false,
190 receive_max: None,
191 topic_alias_max: 0,
192 user_properties: Vec::new(),
193 max_packet_size: None,
194 last_will: None,
195 client_id: ByteString::default(),
196 username: None,
197 password: None,
198 }
199 }
200}
201
202fn decode_last_will(src: &mut Bytes, flags: ConnectFlags) -> Result<LastWill, DecodeError> {
203 let mut will_delay_interval_sec = None;
204 let mut correlation_data = None;
205 let mut message_expiry_interval = None;
206 let mut content_type = None;
207 let mut user_properties = Vec::new();
208 let mut is_utf8_payload = None;
209 let mut response_topic = None;
210 let prop_src = &mut utils::take_properties(src)?;
211 while prop_src.has_remaining() {
212 match prop_src.get_u8() {
213 pt::WILL_DELAY_INT => will_delay_interval_sec.read_value(prop_src)?,
214 pt::CORR_DATA => correlation_data.read_value(prop_src)?,
215 pt::MSG_EXPIRY_INT => message_expiry_interval.read_value(prop_src)?,
216 pt::CONTENT_TYPE => content_type.read_value(prop_src)?,
217 pt::UTF8_PAYLOAD => is_utf8_payload.read_value(prop_src)?,
218 pt::RESP_TOPIC => response_topic.read_value(prop_src)?,
219 pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
220 _ => return Err(DecodeError::MalformedPacket),
221 }
222 }
223
224 let topic = ByteString::decode(src)?;
225 let message = Bytes::decode(src)?;
226 Ok(LastWill {
227 qos: QoS::try_from((flags & ConnectFlags::WILL_QOS).bits() >> WILL_QOS_SHIFT)?,
228 retain: flags.contains(ConnectFlags::WILL_RETAIN),
229 topic,
230 message,
231 will_delay_interval_sec,
232 correlation_data,
233 message_expiry_interval,
234 content_type,
235 user_properties,
236 is_utf8_payload,
237 response_topic,
238 })
239}
240
241impl EncodeLtd for Connect {
242 fn encoded_size(&self, _limit: u32) -> usize {
243 let prop_len = self.properties_len();
244 6 + 1 + 1 + 2 + var_int_len(prop_len) as usize + prop_len + self.client_id.encoded_size()
251 + self.last_will.as_ref().map_or(0, |will| { let prop_len = will.properties_len();
253 var_int_len(prop_len) as usize + prop_len + will.topic.encoded_size() + will.message.encoded_size()
254 })
255 + self.username.as_ref().map_or(0, |v| v.encoded_size())
256 + self.password.as_ref().map_or(0, |v| v.encoded_size())
257 }
258
259 fn encode(&self, buf: &mut BytesMut, _size: u32) -> Result<(), EncodeError> {
260 b"MQTT".as_ref().encode(buf)?;
261
262 let mut flags = ConnectFlags::empty();
263
264 if self.username.is_some() {
265 flags |= ConnectFlags::USERNAME;
266 }
267 if self.password.is_some() {
268 flags |= ConnectFlags::PASSWORD;
269 }
270
271 if let Some(will) = self.last_will.as_ref() {
272 flags |= ConnectFlags::WILL;
273
274 if will.retain {
275 flags |= ConnectFlags::WILL_RETAIN;
276 }
277
278 flags |= ConnectFlags::from_bits_truncate(u8::from(will.qos) << WILL_QOS_SHIFT);
279 }
280
281 if self.clean_start {
282 flags |= ConnectFlags::CLEAN_START;
283 }
284
285 buf.put_slice(&[MQTT_LEVEL_5, flags.bits()]);
286
287 self.keep_alive.encode(buf)?;
288
289 let prop_len = self.properties_len();
290 utils::write_variable_length(prop_len as u32, buf); encode_property_default(&self.session_expiry_interval_secs, 0, pt::SESS_EXPIRY_INT, buf)?;
293 encode_property(&self.auth_method, pt::AUTH_METHOD, buf)?;
294 encode_property(&self.auth_data, pt::AUTH_DATA, buf)?;
295 encode_property_default(&self.request_problem_info, true, pt::REQ_PROB_INFO, buf)?; encode_property_default(&self.request_response_info, false, pt::REQ_RESP_INFO, buf)?; encode_property(&self.receive_max, pt::RECEIVE_MAX, buf)?;
298 encode_property(&self.max_packet_size, pt::MAX_PACKET_SIZE, buf)?;
299 encode_property_default(&self.topic_alias_max, 0, pt::TOPIC_ALIAS_MAX, buf)?;
300 self.user_properties.encode(buf)?;
301
302 self.client_id.encode(buf)?;
303
304 if let Some(will) = self.last_will.as_ref() {
305 let prop_len = will.properties_len();
306 utils::write_variable_length(prop_len as u32, buf); encode_property(&will.will_delay_interval_sec, pt::WILL_DELAY_INT, buf)?;
309 encode_property(&will.is_utf8_payload, pt::UTF8_PAYLOAD, buf)?;
310 encode_property(&will.message_expiry_interval, pt::MSG_EXPIRY_INT, buf)?;
311 encode_property(&will.content_type, pt::CONTENT_TYPE, buf)?;
312 encode_property(&will.response_topic, pt::RESP_TOPIC, buf)?;
313 encode_property(&will.correlation_data, pt::CORR_DATA, buf)?;
314 will.user_properties.encode(buf)?;
315 will.topic.encode(buf)?;
316 will.message.encode(buf)?;
317 }
318 if let Some(s) = self.username.as_ref() {
319 s.encode(buf)?;
320 }
321 if let Some(pwd) = self.password.as_ref() {
322 pwd.encode(buf)?;
323 }
324 Ok(())
325 }
326}