1use std::num::{NonZeroU16, NonZeroU32};
2
3use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
4
5use crate::error::{DecodeError, EncodeError};
6use crate::types::{ConnectFlags, MQTT, MQTT_LEVEL_5, QoS, WILL_QOS_SHIFT};
7use crate::utils::{self, Decode, Encode, Property};
8use crate::v5::codec::{UserProperties, UserProperty, encode::*, property_type as pt};
9
/// MQTT v5 CONNECT packet: connect flags, keep alive and properties from the
/// variable header, plus the payload (client id, optional will, optional credentials).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Connect {
    /// Clean Start flag (`ConnectFlags::CLEAN_START`).
    pub clean_start: bool,
    /// Keep Alive value, written as a two-byte integer right after the connect flags.
    pub keep_alive: u16,

    /// Session Expiry Interval property, in seconds; `0` when the property is absent.
    pub session_expiry_interval_secs: u32,
    /// Authentication Method property.
    pub auth_method: Option<ByteString>,
    /// Authentication Data property.
    pub auth_data: Option<Bytes>,
    /// Request Problem Information property; `true` when the property is absent.
    pub request_problem_info: bool,
    /// Request Response Information property; `false` when the property is absent.
    pub request_response_info: bool,
    /// Receive Maximum property; `None` when absent (the type cannot hold zero).
    pub receive_max: Option<NonZeroU16>,
    /// Topic Alias Maximum property; `0` when the property is absent.
    pub topic_alias_max: u16,
    /// User Property entries, kept in the order they were decoded.
    pub user_properties: UserProperties,
    /// Maximum Packet Size property; `None` when absent.
    pub max_packet_size: Option<NonZeroU32>,

    /// Will message; present exactly when the Will connect flag is set.
    pub last_will: Option<LastWill>,
    /// Client Identifier; always present in the payload.
    pub client_id: ByteString,
    /// User Name; present exactly when the User Name connect flag is set.
    pub username: Option<ByteString>,
    /// Password; present exactly when the Password connect flag is set.
    pub password: Option<Bytes>,
}
37
/// Will message carried in the CONNECT payload when the Will connect flag is set.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LastWill {
    /// Will QoS, carried in the `WILL_QOS` bits of the connect flags.
    pub qos: QoS,
    /// Will Retain flag (`ConnectFlags::WILL_RETAIN`).
    pub retain: bool,
    /// Will Topic, decoded from the payload right after the will properties.
    pub topic: ByteString,
    /// Will Payload, decoded right after the will topic.
    pub message: Bytes,

    /// Will Delay Interval property (seconds, per the field name); `None` when absent.
    pub will_delay_interval_sec: Option<u32>,
    /// Correlation Data property.
    pub correlation_data: Option<Bytes>,
    /// Message Expiry Interval property; `None` when absent.
    pub message_expiry_interval: Option<NonZeroU32>,
    /// Content Type property.
    pub content_type: Option<ByteString>,
    /// User Property entries attached to the will, in decoded order.
    pub user_properties: UserProperties,
    /// Payload Format Indicator property (`pt::UTF8_PAYLOAD`).
    pub is_utf8_payload: Option<bool>,
    /// Response Topic property.
    pub response_topic: Option<ByteString>,
}
58
59impl LastWill {
60 fn properties_len(&self) -> usize {
61 encoded_property_size(&self.will_delay_interval_sec)
62 + encoded_property_size(&self.correlation_data)
63 + encoded_property_size(&self.message_expiry_interval)
64 + encoded_property_size(&self.content_type)
65 + encoded_property_size(&self.is_utf8_payload)
66 + encoded_property_size(&self.response_topic)
67 + self.user_properties.encoded_size()
68 }
69}
70
71impl Connect {
72 pub fn client_id<T>(mut self, client_id: T) -> Self
74 where
75 ByteString: From<T>,
76 {
77 self.client_id = client_id.into();
78 self
79 }
80
81 pub fn receive_max(mut self, max: u16) -> Self {
83 if let Some(num) = NonZeroU16::new(max) {
84 self.receive_max = Some(num);
85 } else {
86 self.receive_max = None;
87 }
88 self
89 }
90
91 fn properties_len(&self) -> usize {
92 encoded_property_size(&self.auth_method)
93 + encoded_property_size(&self.auth_data)
94 + encoded_property_size_default(&self.session_expiry_interval_secs, 0)
95 + encoded_property_size_default(&self.request_problem_info, true) + encoded_property_size_default(&self.request_response_info, false) + encoded_property_size(&self.receive_max)
98 + encoded_property_size(&self.max_packet_size)
99 + encoded_property_size_default(&self.topic_alias_max, 0)
100 + self.user_properties.encoded_size()
101 }
102
103 pub(crate) fn decode(src: &mut Bytes) -> Result<Self, DecodeError> {
104 ensure!(src.remaining() >= 10, DecodeError::InvalidLength);
105 let len = src.get_u16();
106
107 ensure!(len == 4 && &src.as_ref()[0..4] == MQTT, DecodeError::InvalidProtocol);
108 src.advance(4);
109
110 let level = src.get_u8();
111 ensure!(level == MQTT_LEVEL_5, DecodeError::UnsupportedProtocolLevel);
112
113 let flags =
114 ConnectFlags::from_bits(src.get_u8()).ok_or(DecodeError::ConnectReservedFlagSet)?;
115 let keep_alive = src.get_u16();
116
117 let mut session_expiry_interval_secs = None;
119 let mut auth_method = None;
120 let mut auth_data = None;
121 let mut request_problem_info = None;
122 let mut request_response_info = None;
123 let mut receive_max = None;
124 let mut topic_alias_max = None;
125 let mut user_properties = Vec::new();
126 let mut max_packet_size = None;
127 let prop_src = &mut utils::take_properties(src)?;
128 while prop_src.has_remaining() {
129 match prop_src.get_u8() {
130 pt::SESS_EXPIRY_INT => session_expiry_interval_secs.read_value(prop_src)?,
131 pt::AUTH_METHOD => auth_method.read_value(prop_src)?,
132 pt::AUTH_DATA => auth_data.read_value(prop_src)?,
133 pt::REQ_PROB_INFO => request_problem_info.read_value(prop_src)?,
134 pt::REQ_RESP_INFO => request_response_info.read_value(prop_src)?,
135 pt::RECEIVE_MAX => receive_max.read_value(prop_src)?,
136 pt::TOPIC_ALIAS_MAX => topic_alias_max.read_value(prop_src)?,
137 pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
138 pt::MAX_PACKET_SIZE => max_packet_size.read_value(prop_src)?,
139 _ => return Err(DecodeError::MalformedPacket),
140 }
141 }
142
143 let client_id = ByteString::decode(src)?;
144
145 let last_will = if flags.contains(ConnectFlags::WILL) {
146 Some(decode_last_will(src, flags)?)
147 } else {
148 None
149 };
150
151 let username = if flags.contains(ConnectFlags::USERNAME) {
152 Some(ByteString::decode(src)?)
153 } else {
154 None
155 };
156 let password = if flags.contains(ConnectFlags::PASSWORD) {
157 Some(Bytes::decode(src)?)
158 } else {
159 None
160 };
161
162 Ok(Connect {
163 clean_start: flags.contains(ConnectFlags::CLEAN_START),
164 keep_alive,
165 session_expiry_interval_secs: session_expiry_interval_secs.unwrap_or(0),
166 auth_method,
167 auth_data,
168 receive_max,
169 topic_alias_max: topic_alias_max.unwrap_or(0u16),
170 request_problem_info: request_problem_info.unwrap_or(true),
171 request_response_info: request_response_info.unwrap_or(false),
172 user_properties,
173 max_packet_size,
174
175 client_id,
176 last_will,
177 username,
178 password,
179 })
180 }
181}
182
183impl Default for Connect {
184 fn default() -> Connect {
185 Connect {
186 clean_start: false,
187 keep_alive: 0,
188 session_expiry_interval_secs: 0,
189 auth_method: None,
190 auth_data: None,
191 request_problem_info: true,
192 request_response_info: false,
193 receive_max: None,
194 topic_alias_max: 0,
195 user_properties: Vec::new(),
196 max_packet_size: None,
197 last_will: None,
198 client_id: ByteString::default(),
199 username: None,
200 password: None,
201 }
202 }
203}
204
205fn decode_last_will(src: &mut Bytes, flags: ConnectFlags) -> Result<LastWill, DecodeError> {
206 let mut will_delay_interval_sec = None;
207 let mut correlation_data = None;
208 let mut message_expiry_interval = None;
209 let mut content_type = None;
210 let mut user_properties = Vec::new();
211 let mut is_utf8_payload = None;
212 let mut response_topic = None;
213 let prop_src = &mut utils::take_properties(src)?;
214 while prop_src.has_remaining() {
215 match prop_src.get_u8() {
216 pt::WILL_DELAY_INT => will_delay_interval_sec.read_value(prop_src)?,
217 pt::CORR_DATA => correlation_data.read_value(prop_src)?,
218 pt::MSG_EXPIRY_INT => message_expiry_interval.read_value(prop_src)?,
219 pt::CONTENT_TYPE => content_type.read_value(prop_src)?,
220 pt::UTF8_PAYLOAD => is_utf8_payload.read_value(prop_src)?,
221 pt::RESP_TOPIC => response_topic.read_value(prop_src)?,
222 pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
223 _ => return Err(DecodeError::MalformedPacket),
224 }
225 }
226
227 let topic = ByteString::decode(src)?;
228 let message = Bytes::decode(src)?;
229 Ok(LastWill {
230 qos: QoS::try_from((flags & ConnectFlags::WILL_QOS).bits() >> WILL_QOS_SHIFT)?,
231 retain: flags.contains(ConnectFlags::WILL_RETAIN),
232 topic,
233 message,
234 will_delay_interval_sec,
235 correlation_data,
236 message_expiry_interval,
237 content_type,
238 user_properties,
239 is_utf8_payload,
240 response_topic,
241 })
242}
243
244impl EncodeLtd for Connect {
245 fn encoded_size(&self, _limit: u32) -> usize {
246 let prop_len = self.properties_len();
247 6 + 1 + 1 + 2 + var_int_len(prop_len) as usize + prop_len + self.client_id.encoded_size()
254 + self.last_will.as_ref().map_or(0, |will| { let prop_len = will.properties_len();
256 var_int_len(prop_len) as usize + prop_len + will.topic.encoded_size() + will.message.encoded_size()
257 })
258 + self.username.as_ref().map_or(0, |v| v.encoded_size())
259 + self.password.as_ref().map_or(0, |v| v.encoded_size())
260 }
261
262 fn encode(&self, buf: &mut BytesMut, _size: u32) -> Result<(), EncodeError> {
263 b"MQTT".as_ref().encode(buf)?;
264
265 let mut flags = ConnectFlags::empty();
266
267 if self.username.is_some() {
268 flags |= ConnectFlags::USERNAME;
269 }
270 if self.password.is_some() {
271 flags |= ConnectFlags::PASSWORD;
272 }
273
274 if let Some(will) = self.last_will.as_ref() {
275 flags |= ConnectFlags::WILL;
276
277 if will.retain {
278 flags |= ConnectFlags::WILL_RETAIN;
279 }
280
281 flags |= ConnectFlags::from_bits_truncate(u8::from(will.qos) << WILL_QOS_SHIFT);
282 }
283
284 if self.clean_start {
285 flags |= ConnectFlags::CLEAN_START;
286 }
287
288 buf.put_slice(&[MQTT_LEVEL_5, flags.bits()]);
289
290 self.keep_alive.encode(buf)?;
291
292 let prop_len = self.properties_len();
293 utils::write_variable_length(prop_len as u32, buf); encode_property_default(
296 &self.session_expiry_interval_secs,
297 0,
298 pt::SESS_EXPIRY_INT,
299 buf,
300 )?;
301 encode_property(&self.auth_method, pt::AUTH_METHOD, buf)?;
302 encode_property(&self.auth_data, pt::AUTH_DATA, buf)?;
303 encode_property_default(&self.request_problem_info, true, pt::REQ_PROB_INFO, buf)?; encode_property_default(&self.request_response_info, false, pt::REQ_RESP_INFO, buf)?; encode_property(&self.receive_max, pt::RECEIVE_MAX, buf)?;
306 encode_property(&self.max_packet_size, pt::MAX_PACKET_SIZE, buf)?;
307 encode_property_default(&self.topic_alias_max, 0, pt::TOPIC_ALIAS_MAX, buf)?;
308 self.user_properties.encode(buf)?;
309
310 self.client_id.encode(buf)?;
311
312 if let Some(will) = self.last_will.as_ref() {
313 let prop_len = will.properties_len();
314 utils::write_variable_length(prop_len as u32, buf); encode_property(&will.will_delay_interval_sec, pt::WILL_DELAY_INT, buf)?;
317 encode_property(&will.is_utf8_payload, pt::UTF8_PAYLOAD, buf)?;
318 encode_property(&will.message_expiry_interval, pt::MSG_EXPIRY_INT, buf)?;
319 encode_property(&will.content_type, pt::CONTENT_TYPE, buf)?;
320 encode_property(&will.response_topic, pt::RESP_TOPIC, buf)?;
321 encode_property(&will.correlation_data, pt::CORR_DATA, buf)?;
322 will.user_properties.encode(buf)?;
323 will.topic.encode(buf)?;
324 will.message.encode(buf)?;
325 }
326 if let Some(s) = self.username.as_ref() {
327 s.encode(buf)?;
328 }
329 if let Some(pwd) = self.password.as_ref() {
330 pwd.encode(buf)?;
331 }
332 Ok(())
333 }
334}