ntex_mqtt/v5/codec/packet/connect.rs

1use std::num::{NonZeroU16, NonZeroU32};
2
3use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut};
4
5use crate::error::{DecodeError, EncodeError};
6use crate::types::{ConnectFlags, MQTT, MQTT_LEVEL_5, QoS, WILL_QOS_SHIFT};
7use crate::utils::{self, Decode, Encode, Property};
8use crate::v5::codec::{UserProperties, UserProperty, encode::*, property_type as pt};
9
10#[derive(Debug, PartialEq, Eq, Clone)]
11/// Connect packet content
12pub struct Connect {
13    /// the handling of the Session state.
14    pub clean_start: bool,
15    /// a time interval measured in seconds.
16    pub keep_alive: u16,
17
18    pub session_expiry_interval_secs: u32,
19    pub auth_method: Option<ByteString>,
20    pub auth_data: Option<Bytes>,
21    pub request_problem_info: bool,
22    pub request_response_info: bool,
23    pub receive_max: Option<NonZeroU16>,
24    pub topic_alias_max: u16,
25    pub user_properties: UserProperties,
26    pub max_packet_size: Option<NonZeroU32>,
27
28    /// Will Message be stored on the Server and associated with the Network Connection.
29    pub last_will: Option<LastWill>,
30    /// identifies the Client to the Server.
31    pub client_id: ByteString,
32    /// username can be used by the Server for authentication and authorization.
33    pub username: Option<ByteString>,
34    /// password can be used by the Server for authentication and authorization.
35    pub password: Option<Bytes>,
36}
37
38#[derive(Debug, PartialEq, Eq, Clone)]
39/// Connection Will
40pub struct LastWill {
41    /// the QoS level to be used when publishing the Will Message.
42    pub qos: QoS,
43    /// the Will Message is to be Retained when it is published.
44    pub retain: bool,
45    /// the Will Topic
46    pub topic: ByteString,
47    /// defines the Application Message that is to be published to the Will Topic
48    pub message: Bytes,
49
50    pub will_delay_interval_sec: Option<u32>,
51    pub correlation_data: Option<Bytes>,
52    pub message_expiry_interval: Option<NonZeroU32>,
53    pub content_type: Option<ByteString>,
54    pub user_properties: UserProperties,
55    pub is_utf8_payload: Option<bool>,
56    pub response_topic: Option<ByteString>,
57}
58
59impl LastWill {
60    fn properties_len(&self) -> usize {
61        encoded_property_size(&self.will_delay_interval_sec)
62            + encoded_property_size(&self.correlation_data)
63            + encoded_property_size(&self.message_expiry_interval)
64            + encoded_property_size(&self.content_type)
65            + encoded_property_size(&self.is_utf8_payload)
66            + encoded_property_size(&self.response_topic)
67            + self.user_properties.encoded_size()
68    }
69}
70
71impl Connect {
72    /// Set client_id value
73    pub fn client_id<T>(mut self, client_id: T) -> Self
74    where
75        ByteString: From<T>,
76    {
77        self.client_id = client_id.into();
78        self
79    }
80
81    /// Set receive_max value
82    pub fn receive_max(mut self, max: u16) -> Self {
83        if let Some(num) = NonZeroU16::new(max) {
84            self.receive_max = Some(num);
85        } else {
86            self.receive_max = None;
87        }
88        self
89    }
90
91    fn properties_len(&self) -> usize {
92        encoded_property_size(&self.auth_method)
93            + encoded_property_size(&self.auth_data)
94            + encoded_property_size_default(&self.session_expiry_interval_secs, 0)
95            + encoded_property_size_default(&self.request_problem_info, true) // 3.1.2.11.7 Request Problem Information
96            + encoded_property_size_default(&self.request_response_info, false) // 3.1.2.11.6 Request Response Information
97            + encoded_property_size(&self.receive_max)
98            + encoded_property_size(&self.max_packet_size)
99            + encoded_property_size_default(&self.topic_alias_max, 0)
100            + self.user_properties.encoded_size()
101    }
102
103    pub(crate) fn decode(src: &mut Bytes) -> Result<Self, DecodeError> {
104        ensure!(src.remaining() >= 10, DecodeError::InvalidLength);
105        let len = src.get_u16();
106
107        ensure!(len == 4 && &src.as_ref()[0..4] == MQTT, DecodeError::InvalidProtocol);
108        src.advance(4);
109
110        let level = src.get_u8();
111        ensure!(level == MQTT_LEVEL_5, DecodeError::UnsupportedProtocolLevel);
112
113        let flags =
114            ConnectFlags::from_bits(src.get_u8()).ok_or(DecodeError::ConnectReservedFlagSet)?;
115        let keep_alive = src.get_u16();
116
117        // reading properties
118        let mut session_expiry_interval_secs = None;
119        let mut auth_method = None;
120        let mut auth_data = None;
121        let mut request_problem_info = None;
122        let mut request_response_info = None;
123        let mut receive_max = None;
124        let mut topic_alias_max = None;
125        let mut user_properties = Vec::new();
126        let mut max_packet_size = None;
127        let prop_src = &mut utils::take_properties(src)?;
128        while prop_src.has_remaining() {
129            match prop_src.get_u8() {
130                pt::SESS_EXPIRY_INT => session_expiry_interval_secs.read_value(prop_src)?,
131                pt::AUTH_METHOD => auth_method.read_value(prop_src)?,
132                pt::AUTH_DATA => auth_data.read_value(prop_src)?,
133                pt::REQ_PROB_INFO => request_problem_info.read_value(prop_src)?,
134                pt::REQ_RESP_INFO => request_response_info.read_value(prop_src)?,
135                pt::RECEIVE_MAX => receive_max.read_value(prop_src)?,
136                pt::TOPIC_ALIAS_MAX => topic_alias_max.read_value(prop_src)?,
137                pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
138                pt::MAX_PACKET_SIZE => max_packet_size.read_value(prop_src)?,
139                _ => return Err(DecodeError::MalformedPacket),
140            }
141        }
142
143        let client_id = ByteString::decode(src)?;
144
145        let last_will = if flags.contains(ConnectFlags::WILL) {
146            Some(decode_last_will(src, flags)?)
147        } else {
148            None
149        };
150
151        let username = if flags.contains(ConnectFlags::USERNAME) {
152            Some(ByteString::decode(src)?)
153        } else {
154            None
155        };
156        let password = if flags.contains(ConnectFlags::PASSWORD) {
157            Some(Bytes::decode(src)?)
158        } else {
159            None
160        };
161
162        Ok(Connect {
163            clean_start: flags.contains(ConnectFlags::CLEAN_START),
164            keep_alive,
165            session_expiry_interval_secs: session_expiry_interval_secs.unwrap_or(0),
166            auth_method,
167            auth_data,
168            receive_max,
169            topic_alias_max: topic_alias_max.unwrap_or(0u16),
170            request_problem_info: request_problem_info.unwrap_or(true),
171            request_response_info: request_response_info.unwrap_or(false),
172            user_properties,
173            max_packet_size,
174
175            client_id,
176            last_will,
177            username,
178            password,
179        })
180    }
181}
182
183impl Default for Connect {
184    fn default() -> Connect {
185        Connect {
186            clean_start: false,
187            keep_alive: 0,
188            session_expiry_interval_secs: 0,
189            auth_method: None,
190            auth_data: None,
191            request_problem_info: true,
192            request_response_info: false,
193            receive_max: None,
194            topic_alias_max: 0,
195            user_properties: Vec::new(),
196            max_packet_size: None,
197            last_will: None,
198            client_id: ByteString::default(),
199            username: None,
200            password: None,
201        }
202    }
203}
204
205fn decode_last_will(src: &mut Bytes, flags: ConnectFlags) -> Result<LastWill, DecodeError> {
206    let mut will_delay_interval_sec = None;
207    let mut correlation_data = None;
208    let mut message_expiry_interval = None;
209    let mut content_type = None;
210    let mut user_properties = Vec::new();
211    let mut is_utf8_payload = None;
212    let mut response_topic = None;
213    let prop_src = &mut utils::take_properties(src)?;
214    while prop_src.has_remaining() {
215        match prop_src.get_u8() {
216            pt::WILL_DELAY_INT => will_delay_interval_sec.read_value(prop_src)?,
217            pt::CORR_DATA => correlation_data.read_value(prop_src)?,
218            pt::MSG_EXPIRY_INT => message_expiry_interval.read_value(prop_src)?,
219            pt::CONTENT_TYPE => content_type.read_value(prop_src)?,
220            pt::UTF8_PAYLOAD => is_utf8_payload.read_value(prop_src)?,
221            pt::RESP_TOPIC => response_topic.read_value(prop_src)?,
222            pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
223            _ => return Err(DecodeError::MalformedPacket),
224        }
225    }
226
227    let topic = ByteString::decode(src)?;
228    let message = Bytes::decode(src)?;
229    Ok(LastWill {
230        qos: QoS::try_from((flags & ConnectFlags::WILL_QOS).bits() >> WILL_QOS_SHIFT)?,
231        retain: flags.contains(ConnectFlags::WILL_RETAIN),
232        topic,
233        message,
234        will_delay_interval_sec,
235        correlation_data,
236        message_expiry_interval,
237        content_type,
238        user_properties,
239        is_utf8_payload,
240        response_topic,
241    })
242}
243
244impl EncodeLtd for Connect {
245    fn encoded_size(&self, _limit: u32) -> usize {
246        let prop_len = self.properties_len();
247        6 // protocol name
248            + 1 // protocol level
249            + 1 // connect flags
250            + 2 // keep alive
251            + var_int_len(prop_len) as usize // properties len
252            + prop_len // properties
253            + self.client_id.encoded_size()
254            + self.last_will.as_ref().map_or(0, |will| { // will message content
255                let prop_len = will.properties_len();
256                var_int_len(prop_len) as usize + prop_len + will.topic.encoded_size() + will.message.encoded_size()
257            })
258            + self.username.as_ref().map_or(0, |v| v.encoded_size())
259            + self.password.as_ref().map_or(0, |v| v.encoded_size())
260    }
261
262    fn encode(&self, buf: &mut BytesMut, _size: u32) -> Result<(), EncodeError> {
263        b"MQTT".as_ref().encode(buf)?;
264
265        let mut flags = ConnectFlags::empty();
266
267        if self.username.is_some() {
268            flags |= ConnectFlags::USERNAME;
269        }
270        if self.password.is_some() {
271            flags |= ConnectFlags::PASSWORD;
272        }
273
274        if let Some(will) = self.last_will.as_ref() {
275            flags |= ConnectFlags::WILL;
276
277            if will.retain {
278                flags |= ConnectFlags::WILL_RETAIN;
279            }
280
281            flags |= ConnectFlags::from_bits_truncate(u8::from(will.qos) << WILL_QOS_SHIFT);
282        }
283
284        if self.clean_start {
285            flags |= ConnectFlags::CLEAN_START;
286        }
287
288        buf.put_slice(&[MQTT_LEVEL_5, flags.bits()]);
289
290        self.keep_alive.encode(buf)?;
291
292        let prop_len = self.properties_len();
293        utils::write_variable_length(prop_len as u32, buf); // safe: whole message size is vetted via max size check in codec
294
295        encode_property_default(
296            &self.session_expiry_interval_secs,
297            0,
298            pt::SESS_EXPIRY_INT,
299            buf,
300        )?;
301        encode_property(&self.auth_method, pt::AUTH_METHOD, buf)?;
302        encode_property(&self.auth_data, pt::AUTH_DATA, buf)?;
303        encode_property_default(&self.request_problem_info, true, pt::REQ_PROB_INFO, buf)?; // 3.1.2.11.7 Request Problem Information
304        encode_property_default(&self.request_response_info, false, pt::REQ_RESP_INFO, buf)?; // 3.1.2.11.6 Request Response Information
305        encode_property(&self.receive_max, pt::RECEIVE_MAX, buf)?;
306        encode_property(&self.max_packet_size, pt::MAX_PACKET_SIZE, buf)?;
307        encode_property_default(&self.topic_alias_max, 0, pt::TOPIC_ALIAS_MAX, buf)?;
308        self.user_properties.encode(buf)?;
309
310        self.client_id.encode(buf)?;
311
312        if let Some(will) = self.last_will.as_ref() {
313            let prop_len = will.properties_len();
314            utils::write_variable_length(prop_len as u32, buf); // safe: whole message size is checked for max already
315
316            encode_property(&will.will_delay_interval_sec, pt::WILL_DELAY_INT, buf)?;
317            encode_property(&will.is_utf8_payload, pt::UTF8_PAYLOAD, buf)?;
318            encode_property(&will.message_expiry_interval, pt::MSG_EXPIRY_INT, buf)?;
319            encode_property(&will.content_type, pt::CONTENT_TYPE, buf)?;
320            encode_property(&will.response_topic, pt::RESP_TOPIC, buf)?;
321            encode_property(&will.correlation_data, pt::CORR_DATA, buf)?;
322            will.user_properties.encode(buf)?;
323            will.topic.encode(buf)?;
324            will.message.encode(buf)?;
325        }
326        if let Some(s) = self.username.as_ref() {
327            s.encode(buf)?;
328        }
329        if let Some(pwd) = self.password.as_ref() {
330            pwd.encode(buf)?;
331        }
332        Ok(())
333    }
334}