mqtt_proto/v3/
connect.rs

1use std::convert::TryFrom;
2use std::io;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use tokio::io::{AsyncRead, AsyncReadExt};
7
8use crate::{
9    read_bytes, read_string, read_u16, read_u8, write_bytes, write_u16, write_u8, Encodable, Error,
10    Protocol, QoS, TopicName,
11};
12
13/// Connect packet body type.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct Connect {
16    pub protocol: Protocol,
17    pub clean_session: bool,
18    pub keep_alive: u16,
19    pub client_id: Arc<String>,
20    pub last_will: Option<LastWill>,
21    pub username: Option<Arc<String>>,
22    pub password: Option<Bytes>,
23}
24
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Connect {
    /// Build a fuzzing `Connect` by drawing each field from `u`.
    ///
    /// Fields are drawn in declaration order (protocol, clean_session,
    /// keep_alive, client_id, last_will, username, password). The password is
    /// generated as an optional byte vector and then wrapped in `Bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let protocol = u.arbitrary()?;
        let clean_session = u.arbitrary()?;
        let keep_alive = u.arbitrary()?;
        let client_id = u.arbitrary()?;
        let last_will = u.arbitrary()?;
        let username = u.arbitrary()?;
        let password = u.arbitrary::<Option<Vec<u8>>>()?.map(Bytes::from);
        Ok(Connect {
            protocol,
            clean_session,
            keep_alive,
            client_id,
            last_will,
            username,
            password,
        })
    }
}
39
40impl Connect {
41    pub fn new(client_id: Arc<String>, keep_alive: u16) -> Self {
42        Connect {
43            protocol: Protocol::V311,
44            clean_session: true,
45            keep_alive,
46            client_id,
47            last_will: None,
48            username: None,
49            password: None,
50        }
51    }
52
53    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
54        let protocol = Protocol::decode_async(reader).await?;
55        Self::decode_with_protocol(reader, protocol).await
56    }
57
58    #[inline]
59    pub async fn decode_with_protocol<T: AsyncRead + Unpin>(
60        reader: &mut T,
61        protocol: Protocol,
62    ) -> Result<Self, Error> {
63        if protocol as u8 > 4 {
64            return Err(Error::UnexpectedProtocol(protocol));
65        }
66        let connect_flags: u8 = read_u8(reader).await?;
67        if connect_flags & 1 != 0 {
68            return Err(Error::InvalidConnectFlags(connect_flags));
69        }
70        let keep_alive = read_u16(reader).await?;
71        let client_id = Arc::new(read_string(reader).await?);
72        let last_will = if connect_flags & 0b100 != 0 {
73            let topic_name = read_string(reader).await?;
74            let message = read_bytes(reader).await?;
75            let qos = QoS::from_u8((connect_flags & 0b11000) >> 3)?;
76            let retain = (connect_flags & 0b00100000) != 0;
77            Some(LastWill {
78                topic_name: TopicName::try_from(topic_name)?,
79                message: Bytes::from(message),
80                qos,
81                retain,
82            })
83        } else if connect_flags & 0b11000 != 0 {
84            return Err(Error::InvalidConnectFlags(connect_flags));
85        } else {
86            None
87        };
88        let username = if connect_flags & 0b10000000 != 0 {
89            Some(Arc::new(read_string(reader).await?))
90        } else {
91            None
92        };
93        let password = if connect_flags & 0b01000000 != 0 {
94            Some(Bytes::from(read_bytes(reader).await?))
95        } else {
96            None
97        };
98        let clean_session = (connect_flags & 0b10) != 0;
99        Ok(Connect {
100            protocol,
101            keep_alive,
102            client_id,
103            username,
104            password,
105            last_will,
106            clean_session,
107        })
108    }
109}
110
111impl Encodable for Connect {
112    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
113        let mut connect_flags: u8 = 0b00000000;
114        if self.clean_session {
115            connect_flags |= 0b10;
116        }
117        if self.username.is_some() {
118            connect_flags |= 0b10000000;
119        }
120        if self.password.is_some() {
121            connect_flags |= 0b01000000;
122        }
123        if let Some(last_will) = self.last_will.as_ref() {
124            connect_flags |= 0b00000100;
125            connect_flags |= (last_will.qos as u8) << 3;
126            if last_will.retain {
127                connect_flags |= 0b00100000;
128            }
129        }
130
131        self.protocol.encode(writer)?;
132        write_u8(writer, connect_flags)?;
133        write_u16(writer, self.keep_alive)?;
134        write_bytes(writer, self.client_id.as_bytes())?;
135        if let Some(last_will) = self.last_will.as_ref() {
136            last_will.encode(writer)?;
137        }
138        if let Some(username) = self.username.as_ref() {
139            write_bytes(writer, username.as_bytes())?;
140        }
141        if let Some(password) = self.password.as_ref() {
142            write_bytes(writer, password.as_ref())?;
143        }
144        Ok(())
145    }
146
147    fn encode_len(&self) -> usize {
148        let mut length = self.protocol.encode_len();
149        // flags + keep-alive
150        length += 1 + 2;
151        // client identifier
152        length += 2 + self.client_id.len();
153        if let Some(last_will) = self.last_will.as_ref() {
154            length += last_will.encode_len();
155        }
156        if let Some(username) = self.username.as_ref() {
157            length += 2 + username.len();
158        }
159        if let Some(password) = self.password.as_ref() {
160            length += 2 + password.len();
161        }
162        length
163    }
164}
165
166/// Connack packet body type.
167#[derive(Debug, Clone, Copy, PartialEq, Eq)]
168#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
169pub struct Connack {
170    pub session_present: bool,
171    pub code: ConnectReturnCode,
172}
173
174impl Connack {
175    pub fn new(session_present: bool, code: ConnectReturnCode) -> Self {
176        Connack {
177            session_present,
178            code,
179        }
180    }
181
182    pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
183        let mut payload = [0u8; 2];
184        reader.read_exact(&mut payload).await?;
185        let session_present = match payload[0] {
186            0 => false,
187            1 => true,
188            _ => return Err(Error::InvalidConnackFlags(payload[0])),
189        };
190        let code = ConnectReturnCode::from_u8(payload[1])?;
191        Ok(Connack {
192            session_present,
193            code,
194        })
195    }
196}
197
198/// Message that the server should publish when the client disconnects.
199///
200/// Sent by the client in the [Connect] packet. [MQTT 3.1.3.3].
201///
202/// [Connect]: struct.Connect.html
203/// [MQTT 3.1.3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031
204#[derive(Debug, Clone, PartialEq, Eq)]
205pub struct LastWill {
206    pub qos: QoS,
207    pub retain: bool,
208    pub topic_name: TopicName,
209    pub message: Bytes,
210}
211
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for LastWill {
    /// Build a fuzzing `LastWill`.
    ///
    /// Draws qos, retain, topic_name and then the message bytes from `u`,
    /// in that order.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let qos = u.arbitrary()?;
        let retain = u.arbitrary()?;
        let topic_name = u.arbitrary()?;
        let payload: Vec<u8> = u.arbitrary()?;
        Ok(LastWill {
            qos,
            retain,
            topic_name,
            message: Bytes::from(payload),
        })
    }
}
223
224impl LastWill {
225    pub fn new(qos: QoS, topic_name: TopicName, message: Bytes) -> Self {
226        LastWill {
227            qos,
228            retain: false,
229            topic_name,
230            message,
231        }
232    }
233}
234
235impl Encodable for LastWill {
236    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
237        write_bytes(writer, self.topic_name.as_bytes())?;
238        write_bytes(writer, self.message.as_ref())?;
239        Ok(())
240    }
241
242    fn encode_len(&self) -> usize {
243        4 + self.topic_name.len() + self.message.len()
244    }
245}
246
247/// Return code of a [Connack] packet.
248///
249/// See [MQTT 3.2.2.3] for interpretations.
250///
251/// [Connack]: struct.Connack.html
252/// [MQTT 3.2.2.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035
253#[repr(u8)]
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
255#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
256pub enum ConnectReturnCode {
257    Accepted = 0,
258    UnacceptableProtocolVersion = 1,
259    IdentifierRejected = 2,
260    ServerUnavailable = 3,
261    BadUserNameOrPassword = 4,
262    NotAuthorized = 5,
263}
264
265impl ConnectReturnCode {
266    pub fn from_u8(byte: u8) -> Result<ConnectReturnCode, Error> {
267        match byte {
268            0 => Ok(ConnectReturnCode::Accepted),
269            1 => Ok(ConnectReturnCode::UnacceptableProtocolVersion),
270            2 => Ok(ConnectReturnCode::IdentifierRejected),
271            3 => Ok(ConnectReturnCode::ServerUnavailable),
272            4 => Ok(ConnectReturnCode::BadUserNameOrPassword),
273            5 => Ok(ConnectReturnCode::NotAuthorized),
274            n => Err(Error::InvalidConnectReturnCode(n)),
275        }
276    }
277}