rmqtt_codec/v5/packet/
connect.rs

1use std::num::{NonZeroU16, NonZeroU32};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use bytestring::ByteString;
5use serde::{Deserialize, Serialize};
6
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{ConnectFlags, QoS, MQTT, MQTT_LEVEL_5, WILL_QOS_SHIFT};
9use crate::utils::{self, Decode, Encode, Property};
10use crate::v5::{encode::*, property_type as pt, UserProperties, UserProperty};
11
12#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
13/// Connect packet content
14pub struct Connect {
15    /// the handling of the Session state.
16    pub clean_start: bool,
17    /// a time interval measured in seconds.
18    pub keep_alive: u16,
19
20    pub session_expiry_interval_secs: u32,
21    pub auth_method: Option<ByteString>,
22    pub auth_data: Option<Bytes>,
23    pub request_problem_info: bool,
24    pub request_response_info: bool,
25    pub receive_max: Option<NonZeroU16>,
26    pub topic_alias_max: u16,
27    pub user_properties: UserProperties,
28    pub max_packet_size: Option<NonZeroU32>,
29
30    /// Will Message be stored on the Server and associated with the Network Connection.
31    pub last_will: Option<LastWill>,
32    /// identifies the Client to the Server.
33    pub client_id: ByteString,
34    /// username can be used by the Server for authentication and authorization.
35    pub username: Option<ByteString>,
36    /// password can be used by the Server for authentication and authorization.
37    pub password: Option<Bytes>,
38}
39
40#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
41/// Connection Will
42pub struct LastWill {
43    /// the QoS level to be used when publishing the Will Message.
44    pub qos: QoS,
45    /// the Will Message is to be Retained when it is published.
46    pub retain: bool,
47    /// the Will Topic
48    pub topic: ByteString,
49    /// defines the Application Message that is to be published to the Will Topic
50    pub message: Bytes,
51
52    pub will_delay_interval_sec: Option<u32>,
53    pub correlation_data: Option<Bytes>,
54    pub message_expiry_interval: Option<NonZeroU32>,
55    pub content_type: Option<ByteString>,
56    pub user_properties: UserProperties,
57    pub is_utf8_payload: Option<bool>,
58    pub response_topic: Option<ByteString>,
59}
60
61impl LastWill {
62    fn properties_len(&self) -> usize {
63        encoded_property_size(&self.will_delay_interval_sec)
64            + encoded_property_size(&self.correlation_data)
65            + encoded_property_size(&self.message_expiry_interval)
66            + encoded_property_size(&self.content_type)
67            + encoded_property_size(&self.is_utf8_payload)
68            + encoded_property_size(&self.response_topic)
69            + self.user_properties.encoded_size()
70    }
71}
72
73impl Connect {
74    /// Set client_id value
75    pub fn client_id<T>(mut self, client_id: T) -> Self
76    where
77        ByteString: From<T>,
78    {
79        self.client_id = client_id.into();
80        self
81    }
82
83    /// Set receive_max value
84    pub fn receive_max(mut self, max: u16) -> Self {
85        if let Some(num) = NonZeroU16::new(max) {
86            self.receive_max = Some(num);
87        } else {
88            self.receive_max = None;
89        }
90        self
91    }
92
93    fn properties_len(&self) -> usize {
94        encoded_property_size(&self.auth_method)
95            + encoded_property_size(&self.auth_data)
96            + encoded_property_size_default(&self.session_expiry_interval_secs, 0)
97            + encoded_property_size_default(&self.request_problem_info, true) // 3.1.2.11.7 Request Problem Information
98            + encoded_property_size_default(&self.request_response_info, false) // 3.1.2.11.6 Request Response Information
99            + encoded_property_size(&self.receive_max)
100            + encoded_property_size(&self.max_packet_size)
101            + encoded_property_size_default(&self.topic_alias_max, 0)
102            + self.user_properties.encoded_size()
103    }
104
105    pub(crate) fn decode(src: &mut Bytes) -> Result<Self, DecodeError> {
106        ensure!(src.remaining() >= 10, DecodeError::InvalidLength);
107        let len = src.get_u16();
108
109        ensure!(len == 4 && &src.as_ref()[0..4] == MQTT, DecodeError::InvalidProtocol);
110        src.advance(4);
111
112        let level = src.get_u8();
113        ensure!(level == MQTT_LEVEL_5, DecodeError::UnsupportedProtocolLevel);
114
115        let flags = ConnectFlags::from_bits(src.get_u8()).ok_or(DecodeError::ConnectReservedFlagSet)?;
116        let keep_alive = src.get_u16();
117
118        // reading properties
119        let mut session_expiry_interval_secs = None;
120        let mut auth_method = None;
121        let mut auth_data = None;
122        let mut request_problem_info = None;
123        let mut request_response_info = None;
124        let mut receive_max = None;
125        let mut topic_alias_max = None;
126        let mut user_properties = Vec::new();
127        let mut max_packet_size = None;
128        let prop_src = &mut utils::take_properties(src)?;
129        while prop_src.has_remaining() {
130            match prop_src.get_u8() {
131                pt::SESS_EXPIRY_INT => session_expiry_interval_secs.read_value(prop_src)?,
132                pt::AUTH_METHOD => auth_method.read_value(prop_src)?,
133                pt::AUTH_DATA => auth_data.read_value(prop_src)?,
134                pt::REQ_PROB_INFO => request_problem_info.read_value(prop_src)?,
135                pt::REQ_RESP_INFO => request_response_info.read_value(prop_src)?,
136                pt::RECEIVE_MAX => receive_max.read_value(prop_src)?,
137                pt::TOPIC_ALIAS_MAX => topic_alias_max.read_value(prop_src)?,
138                pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
139                pt::MAX_PACKET_SIZE => max_packet_size.read_value(prop_src)?,
140                _ => return Err(DecodeError::MalformedPacket),
141            }
142        }
143
144        let client_id = ByteString::decode(src)?;
145
146        ensure!(
147            // todo: [MQTT-3.1.3-8]?
148            !client_id.is_empty() || flags.contains(ConnectFlags::CLEAN_START),
149            DecodeError::InvalidClientId
150        );
151
152        let last_will =
153            if flags.contains(ConnectFlags::WILL) { Some(decode_last_will(src, flags)?) } else { None };
154
155        let username =
156            if flags.contains(ConnectFlags::USERNAME) { Some(ByteString::decode(src)?) } else { None };
157        let password = if flags.contains(ConnectFlags::PASSWORD) { Some(Bytes::decode(src)?) } else { None };
158
159        Ok(Connect {
160            clean_start: flags.contains(ConnectFlags::CLEAN_START),
161            keep_alive,
162            session_expiry_interval_secs: session_expiry_interval_secs.unwrap_or(0),
163            auth_method,
164            auth_data,
165            receive_max,
166            topic_alias_max: topic_alias_max.unwrap_or(0u16),
167            request_problem_info: request_problem_info.unwrap_or(true),
168            request_response_info: request_response_info.unwrap_or(false),
169            user_properties,
170            max_packet_size,
171
172            client_id,
173            last_will,
174            username,
175            password,
176        })
177    }
178}
179
180impl Default for Connect {
181    fn default() -> Connect {
182        Connect {
183            clean_start: false,
184            keep_alive: 0,
185            session_expiry_interval_secs: 0,
186            auth_method: None,
187            auth_data: None,
188            request_problem_info: true,
189            request_response_info: false,
190            receive_max: None,
191            topic_alias_max: 0,
192            user_properties: Vec::new(),
193            max_packet_size: None,
194            last_will: None,
195            client_id: ByteString::default(),
196            username: None,
197            password: None,
198        }
199    }
200}
201
202fn decode_last_will(src: &mut Bytes, flags: ConnectFlags) -> Result<LastWill, DecodeError> {
203    let mut will_delay_interval_sec = None;
204    let mut correlation_data = None;
205    let mut message_expiry_interval = None;
206    let mut content_type = None;
207    let mut user_properties = Vec::new();
208    let mut is_utf8_payload = None;
209    let mut response_topic = None;
210    let prop_src = &mut utils::take_properties(src)?;
211    while prop_src.has_remaining() {
212        match prop_src.get_u8() {
213            pt::WILL_DELAY_INT => will_delay_interval_sec.read_value(prop_src)?,
214            pt::CORR_DATA => correlation_data.read_value(prop_src)?,
215            pt::MSG_EXPIRY_INT => message_expiry_interval.read_value(prop_src)?,
216            pt::CONTENT_TYPE => content_type.read_value(prop_src)?,
217            pt::UTF8_PAYLOAD => is_utf8_payload.read_value(prop_src)?,
218            pt::RESP_TOPIC => response_topic.read_value(prop_src)?,
219            pt::USER => user_properties.push(UserProperty::decode(prop_src)?),
220            _ => return Err(DecodeError::MalformedPacket),
221        }
222    }
223
224    let topic = ByteString::decode(src)?;
225    let message = Bytes::decode(src)?;
226    Ok(LastWill {
227        qos: QoS::try_from((flags & ConnectFlags::WILL_QOS).bits() >> WILL_QOS_SHIFT)?,
228        retain: flags.contains(ConnectFlags::WILL_RETAIN),
229        topic,
230        message,
231        will_delay_interval_sec,
232        correlation_data,
233        message_expiry_interval,
234        content_type,
235        user_properties,
236        is_utf8_payload,
237        response_topic,
238    })
239}
240
241impl EncodeLtd for Connect {
242    fn encoded_size(&self, _limit: u32) -> usize {
243        let prop_len = self.properties_len();
244        6 // protocol name
245            + 1 // protocol level
246            + 1 // connect flags
247            + 2 // keep alive
248            + var_int_len(prop_len) as usize // properties len
249            + prop_len // properties
250            + self.client_id.encoded_size()
251            + self.last_will.as_ref().map_or(0, |will| { // will message content
252                let prop_len = will.properties_len();
253                var_int_len(prop_len) as usize + prop_len + will.topic.encoded_size() + will.message.encoded_size()
254            })
255            + self.username.as_ref().map_or(0, |v| v.encoded_size())
256            + self.password.as_ref().map_or(0, |v| v.encoded_size())
257    }
258
259    fn encode(&self, buf: &mut BytesMut, _size: u32) -> Result<(), EncodeError> {
260        b"MQTT".as_ref().encode(buf)?;
261
262        let mut flags = ConnectFlags::empty();
263
264        if self.username.is_some() {
265            flags |= ConnectFlags::USERNAME;
266        }
267        if self.password.is_some() {
268            flags |= ConnectFlags::PASSWORD;
269        }
270
271        if let Some(will) = self.last_will.as_ref() {
272            flags |= ConnectFlags::WILL;
273
274            if will.retain {
275                flags |= ConnectFlags::WILL_RETAIN;
276            }
277
278            flags |= ConnectFlags::from_bits_truncate(u8::from(will.qos) << WILL_QOS_SHIFT);
279        }
280
281        if self.clean_start {
282            flags |= ConnectFlags::CLEAN_START;
283        }
284
285        buf.put_slice(&[MQTT_LEVEL_5, flags.bits()]);
286
287        self.keep_alive.encode(buf)?;
288
289        let prop_len = self.properties_len();
290        utils::write_variable_length(prop_len as u32, buf); // safe: whole message size is vetted via max size check in codec
291
292        encode_property_default(&self.session_expiry_interval_secs, 0, pt::SESS_EXPIRY_INT, buf)?;
293        encode_property(&self.auth_method, pt::AUTH_METHOD, buf)?;
294        encode_property(&self.auth_data, pt::AUTH_DATA, buf)?;
295        encode_property_default(&self.request_problem_info, true, pt::REQ_PROB_INFO, buf)?; // 3.1.2.11.7 Request Problem Information
296        encode_property_default(&self.request_response_info, false, pt::REQ_RESP_INFO, buf)?; // 3.1.2.11.6 Request Response Information
297        encode_property(&self.receive_max, pt::RECEIVE_MAX, buf)?;
298        encode_property(&self.max_packet_size, pt::MAX_PACKET_SIZE, buf)?;
299        encode_property_default(&self.topic_alias_max, 0, pt::TOPIC_ALIAS_MAX, buf)?;
300        self.user_properties.encode(buf)?;
301
302        self.client_id.encode(buf)?;
303
304        if let Some(will) = self.last_will.as_ref() {
305            let prop_len = will.properties_len();
306            utils::write_variable_length(prop_len as u32, buf); // safe: whole message size is checked for max already
307
308            encode_property(&will.will_delay_interval_sec, pt::WILL_DELAY_INT, buf)?;
309            encode_property(&will.is_utf8_payload, pt::UTF8_PAYLOAD, buf)?;
310            encode_property(&will.message_expiry_interval, pt::MSG_EXPIRY_INT, buf)?;
311            encode_property(&will.content_type, pt::CONTENT_TYPE, buf)?;
312            encode_property(&will.response_topic, pt::RESP_TOPIC, buf)?;
313            encode_property(&will.correlation_data, pt::CORR_DATA, buf)?;
314            will.user_properties.encode(buf)?;
315            will.topic.encode(buf)?;
316            will.message.encode(buf)?;
317        }
318        if let Some(s) = self.username.as_ref() {
319            s.encode(buf)?;
320        }
321        if let Some(pwd) = self.password.as_ref() {
322            pwd.encode(buf)?;
323        }
324        Ok(())
325    }
326}