// mqtt_proto/v3/connect.rs

1use core::convert::TryFrom;
2
3use alloc::string::String;
4use alloc::sync::Arc;
5
6use bytes::Bytes;
7
8use crate::{
9    from_read_exact_error, read_bytes, read_string, read_u16, read_u8, write_bytes, write_string,
10    write_u16, write_u8, AsyncRead, Encodable, Error, Protocol, QoS, SyncWrite, TopicName,
11};
12
13/// Connect packet body type.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct Connect {
16    pub protocol: Protocol,
17    pub clean_session: bool,
18    pub keep_alive: u16,
19    pub client_id: Arc<String>,
20    pub last_will: Option<LastWill>,
21    pub username: Option<Arc<String>>,
22    pub password: Option<Bytes>,
23}
24
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Connect {
    /// Builds a random `Connect` for fuzzing.
    ///
    /// Each draw consumes bytes from `u`, so the order of the draws below decides
    /// which input bytes feed which field.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let protocol: Protocol = u.arbitrary()?;
        let clean_session: bool = u.arbitrary()?;
        let keep_alive: u16 = u.arbitrary()?;
        let client_id: Arc<String> = u.arbitrary()?;
        let last_will: Option<LastWill> = u.arbitrary()?;
        let username: Option<Arc<String>> = u.arbitrary()?;
        // `Bytes` has no `Arbitrary` impl here, so draw a `Vec<u8>` and convert it.
        let raw_password: Option<Vec<u8>> = u.arbitrary()?;
        Ok(Connect {
            protocol,
            clean_session,
            keep_alive,
            client_id,
            last_will,
            username,
            password: raw_password.map(Bytes::from),
        })
    }
}
39
40impl Connect {
41    pub fn new(client_id: Arc<String>, keep_alive: u16) -> Self {
42        Connect {
43            protocol: Protocol::V311,
44            clean_session: true,
45            keep_alive,
46            client_id,
47            last_will: None,
48            username: None,
49            password: None,
50        }
51    }
52
53    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
54        let protocol = Protocol::decode_async(reader).await?;
55        Self::decode_with_protocol(reader, protocol).await
56    }
57
58    #[inline]
59    pub async fn decode_with_protocol<T: AsyncRead + Unpin>(
60        reader: &mut T,
61        protocol: Protocol,
62    ) -> Result<Self, Error> {
63        if protocol as u8 > 4 {
64            return Err(Error::UnexpectedProtocol(protocol));
65        }
66        let connect_flags: u8 = read_u8(reader).await?;
67        if connect_flags & 1 != 0 {
68            return Err(Error::InvalidConnectFlags(connect_flags));
69        }
70        let keep_alive = read_u16(reader).await?;
71        let client_id = Arc::new(read_string(reader).await?);
72        let last_will = if connect_flags & 0b100 != 0 {
73            let topic_name = read_string(reader).await?;
74            let message = read_bytes(reader).await?;
75            let qos = QoS::from_u8((connect_flags & 0b11000) >> 3)?;
76            let retain = (connect_flags & 0b00100000) != 0;
77            Some(LastWill {
78                topic_name: TopicName::try_from(topic_name)?,
79                message: Bytes::from(message),
80                qos,
81                retain,
82            })
83        } else if connect_flags & 0b11000 != 0 {
84            return Err(Error::InvalidConnectFlags(connect_flags));
85        } else {
86            None
87        };
88        let username = if connect_flags & 0b10000000 != 0 {
89            Some(Arc::new(read_string(reader).await?))
90        } else {
91            None
92        };
93        let password = if connect_flags & 0b01000000 != 0 {
94            Some(Bytes::from(read_bytes(reader).await?))
95        } else {
96            None
97        };
98        let clean_session = (connect_flags & 0b10) != 0;
99        Ok(Connect {
100            protocol,
101            keep_alive,
102            client_id,
103            username,
104            password,
105            last_will,
106            clean_session,
107        })
108    }
109}
110
111impl Encodable for Connect {
112    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
113        let mut connect_flags: u8 = 0b00000000;
114        if self.clean_session {
115            connect_flags |= 0b10;
116        }
117        if self.username.is_some() {
118            connect_flags |= 0b10000000;
119        }
120        if self.password.is_some() {
121            connect_flags |= 0b01000000;
122        }
123        if let Some(last_will) = self.last_will.as_ref() {
124            connect_flags |= 0b00000100;
125            connect_flags |= (last_will.qos as u8) << 3;
126            if last_will.retain {
127                connect_flags |= 0b00100000;
128            }
129        }
130
131        self.protocol.encode(writer)?;
132        write_u8(writer, connect_flags)?;
133        write_u16(writer, self.keep_alive)?;
134        write_string(writer, &self.client_id)?;
135        if let Some(last_will) = self.last_will.as_ref() {
136            last_will.encode(writer)?;
137        }
138        if let Some(username) = self.username.as_ref() {
139            write_string(writer, username)?;
140        }
141        if let Some(password) = self.password.as_ref() {
142            write_bytes(writer, password.as_ref())?;
143        }
144        Ok(())
145    }
146
147    fn encode_len(&self) -> usize {
148        let mut length = self.protocol.encode_len();
149        // flags + keep-alive
150        length += 1 + 2;
151        // client identifier
152        length += 2 + self.client_id.len();
153        if let Some(last_will) = self.last_will.as_ref() {
154            length += last_will.encode_len();
155        }
156        if let Some(username) = self.username.as_ref() {
157            length += 2 + username.len();
158        }
159        if let Some(password) = self.password.as_ref() {
160            length += 2 + password.len();
161        }
162        length
163    }
164}
165
166/// Connack packet body type.
167#[derive(Debug, Clone, Copy, PartialEq, Eq)]
168#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
169pub struct Connack {
170    pub session_present: bool,
171    pub code: ConnectReturnCode,
172}
173
174impl Connack {
175    pub fn new(session_present: bool, code: ConnectReturnCode) -> Self {
176        Connack {
177            session_present,
178            code,
179        }
180    }
181
182    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
183        let mut payload = [0u8; 2];
184        reader
185            .read_exact(&mut payload)
186            .await
187            .map_err(from_read_exact_error)?;
188        let session_present = match payload[0] {
189            0 => false,
190            1 => true,
191            _ => return Err(Error::InvalidConnackFlags(payload[0])),
192        };
193        let code = ConnectReturnCode::from_u8(payload[1])?;
194        Ok(Connack {
195            session_present,
196            code,
197        })
198    }
199}
200
201/// Message that the server should publish when the client disconnects.
202///
203/// Sent by the client in the [Connect] packet. [MQTT 3.1.3.3].
204///
205/// [Connect]: struct.Connect.html
206/// [MQTT 3.1.3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031
207#[derive(Debug, Clone, PartialEq, Eq)]
208pub struct LastWill {
209    pub qos: QoS,
210    pub retain: bool,
211    pub topic_name: TopicName,
212    pub message: Bytes,
213}
214
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for LastWill {
    /// Builds a random `LastWill` for fuzzing.
    ///
    /// The draw order (qos, retain, topic, payload) decides which input bytes feed
    /// which field.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let qos: QoS = u.arbitrary()?;
        let retain: bool = u.arbitrary()?;
        let topic_name: TopicName = u.arbitrary()?;
        // Draw raw bytes and wrap them, since the payload is stored as `Bytes`.
        let payload: Vec<u8> = u.arbitrary()?;
        Ok(LastWill {
            qos,
            retain,
            topic_name,
            message: Bytes::from(payload),
        })
    }
}
226
227impl LastWill {
228    pub fn new(qos: QoS, topic_name: TopicName, message: Bytes) -> Self {
229        LastWill {
230            qos,
231            retain: false,
232            topic_name,
233            message,
234        }
235    }
236}
237
238impl Encodable for LastWill {
239    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
240        write_string(writer, &self.topic_name)?;
241        write_bytes(writer, self.message.as_ref())?;
242        Ok(())
243    }
244
245    fn encode_len(&self) -> usize {
246        4 + self.topic_name.len() + self.message.len()
247    }
248}
249
/// Return code carried in a [`Connack`] packet.
///
/// See [MQTT 3.2.2.3] for the meaning of each value.
///
/// [MQTT 3.2.2.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum ConnectReturnCode {
    /// `0x00`: connection accepted.
    Accepted = 0,
    /// `0x01`: the requested protocol level is not supported.
    UnacceptableProtocolVersion = 1,
    /// `0x02`: the client identifier was rejected.
    IdentifierRejected = 2,
    /// `0x03`: the MQTT service is unavailable.
    ServerUnavailable = 3,
    /// `0x04`: the username or password is malformed.
    BadUserNameOrPassword = 4,
    /// `0x05`: the client is not authorized to connect.
    NotAuthorized = 5,
}
267
268impl ConnectReturnCode {
269    pub fn from_u8(byte: u8) -> Result<ConnectReturnCode, Error> {
270        match byte {
271            0 => Ok(ConnectReturnCode::Accepted),
272            1 => Ok(ConnectReturnCode::UnacceptableProtocolVersion),
273            2 => Ok(ConnectReturnCode::IdentifierRejected),
274            3 => Ok(ConnectReturnCode::ServerUnavailable),
275            4 => Ok(ConnectReturnCode::BadUserNameOrPassword),
276            5 => Ok(ConnectReturnCode::NotAuthorized),
277            n => Err(Error::InvalidConnectReturnCode(n)),
278        }
279    }
280}