1use std::convert::TryFrom;
2use std::io;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use tokio::io::{AsyncRead, AsyncReadExt};
7
8use crate::{
9 read_bytes, read_string, read_u16, read_u8, write_bytes, write_u16, write_u8, Encodable, Error,
10 Protocol, QoS, TopicName,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct Connect {
16 pub protocol: Protocol,
17 pub clean_session: bool,
18 pub keep_alive: u16,
19 pub client_id: Arc<String>,
20 pub last_will: Option<LastWill>,
21 pub username: Option<Arc<String>>,
22 pub password: Option<Bytes>,
23}
24
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Connect {
    /// Builds a fuzzer-driven `Connect`.
    ///
    /// Fields are drawn from `u` in a fixed order (protocol, clean_session,
    /// keep_alive, client_id, last_will, username, password). Keep that order
    /// stable so a given fuzz input keeps mapping to the same packet.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let protocol = u.arbitrary()?;
        let clean_session = u.arbitrary()?;
        let keep_alive = u.arbitrary()?;
        let client_id = u.arbitrary()?;
        let last_will = u.arbitrary()?;
        let username = u.arbitrary()?;
        // The password is generated as raw `Vec<u8>` and then wrapped in `Bytes`.
        let password = u.arbitrary::<Option<Vec<u8>>>()?.map(Bytes::from);
        Ok(Self {
            protocol,
            clean_session,
            keep_alive,
            client_id,
            last_will,
            username,
            password,
        })
    }
}
39
40impl Connect {
41 pub fn new(client_id: Arc<String>, keep_alive: u16) -> Self {
42 Connect {
43 protocol: Protocol::V311,
44 clean_session: true,
45 keep_alive,
46 client_id,
47 last_will: None,
48 username: None,
49 password: None,
50 }
51 }
52
53 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
54 let protocol = Protocol::decode_async(reader).await?;
55 Self::decode_with_protocol(reader, protocol).await
56 }
57
58 #[inline]
59 pub async fn decode_with_protocol<T: AsyncRead + Unpin>(
60 reader: &mut T,
61 protocol: Protocol,
62 ) -> Result<Self, Error> {
63 if protocol as u8 > 4 {
64 return Err(Error::UnexpectedProtocol(protocol));
65 }
66 let connect_flags: u8 = read_u8(reader).await?;
67 if connect_flags & 1 != 0 {
68 return Err(Error::InvalidConnectFlags(connect_flags));
69 }
70 let keep_alive = read_u16(reader).await?;
71 let client_id = Arc::new(read_string(reader).await?);
72 let last_will = if connect_flags & 0b100 != 0 {
73 let topic_name = read_string(reader).await?;
74 let message = read_bytes(reader).await?;
75 let qos = QoS::from_u8((connect_flags & 0b11000) >> 3)?;
76 let retain = (connect_flags & 0b00100000) != 0;
77 Some(LastWill {
78 topic_name: TopicName::try_from(topic_name)?,
79 message: Bytes::from(message),
80 qos,
81 retain,
82 })
83 } else if connect_flags & 0b11000 != 0 {
84 return Err(Error::InvalidConnectFlags(connect_flags));
85 } else {
86 None
87 };
88 let username = if connect_flags & 0b10000000 != 0 {
89 Some(Arc::new(read_string(reader).await?))
90 } else {
91 None
92 };
93 let password = if connect_flags & 0b01000000 != 0 {
94 Some(Bytes::from(read_bytes(reader).await?))
95 } else {
96 None
97 };
98 let clean_session = (connect_flags & 0b10) != 0;
99 Ok(Connect {
100 protocol,
101 keep_alive,
102 client_id,
103 username,
104 password,
105 last_will,
106 clean_session,
107 })
108 }
109}
110
111impl Encodable for Connect {
112 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
113 let mut connect_flags: u8 = 0b00000000;
114 if self.clean_session {
115 connect_flags |= 0b10;
116 }
117 if self.username.is_some() {
118 connect_flags |= 0b10000000;
119 }
120 if self.password.is_some() {
121 connect_flags |= 0b01000000;
122 }
123 if let Some(last_will) = self.last_will.as_ref() {
124 connect_flags |= 0b00000100;
125 connect_flags |= (last_will.qos as u8) << 3;
126 if last_will.retain {
127 connect_flags |= 0b00100000;
128 }
129 }
130
131 self.protocol.encode(writer)?;
132 write_u8(writer, connect_flags)?;
133 write_u16(writer, self.keep_alive)?;
134 write_bytes(writer, self.client_id.as_bytes())?;
135 if let Some(last_will) = self.last_will.as_ref() {
136 last_will.encode(writer)?;
137 }
138 if let Some(username) = self.username.as_ref() {
139 write_bytes(writer, username.as_bytes())?;
140 }
141 if let Some(password) = self.password.as_ref() {
142 write_bytes(writer, password.as_ref())?;
143 }
144 Ok(())
145 }
146
147 fn encode_len(&self) -> usize {
148 let mut length = self.protocol.encode_len();
149 length += 1 + 2;
151 length += 2 + self.client_id.len();
153 if let Some(last_will) = self.last_will.as_ref() {
154 length += last_will.encode_len();
155 }
156 if let Some(username) = self.username.as_ref() {
157 length += 2 + username.len();
158 }
159 if let Some(password) = self.password.as_ref() {
160 length += 2 + password.len();
161 }
162 length
163 }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq)]
168#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
169pub struct Connack {
170 pub session_present: bool,
171 pub code: ConnectReturnCode,
172}
173
174impl Connack {
175 pub fn new(session_present: bool, code: ConnectReturnCode) -> Self {
176 Connack {
177 session_present,
178 code,
179 }
180 }
181
182 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
183 let mut payload = [0u8; 2];
184 reader.read_exact(&mut payload).await?;
185 let session_present = match payload[0] {
186 0 => false,
187 1 => true,
188 _ => return Err(Error::InvalidConnackFlags(payload[0])),
189 };
190 let code = ConnectReturnCode::from_u8(payload[1])?;
191 Ok(Connack {
192 session_present,
193 code,
194 })
195 }
196}
197
198#[derive(Debug, Clone, PartialEq, Eq)]
205pub struct LastWill {
206 pub qos: QoS,
207 pub retain: bool,
208 pub topic_name: TopicName,
209 pub message: Bytes,
210}
211
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for LastWill {
    /// Builds a fuzzer-driven `LastWill`.
    ///
    /// Values are drawn in the order qos, retain, topic_name, message. The
    /// message is generated as `Vec<u8>` and wrapped in `Bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let qos = u.arbitrary()?;
        let retain = u.arbitrary()?;
        let topic_name = u.arbitrary()?;
        let payload: Vec<u8> = u.arbitrary()?;
        Ok(Self {
            qos,
            retain,
            topic_name,
            message: Bytes::from(payload),
        })
    }
}
223
224impl LastWill {
225 pub fn new(qos: QoS, topic_name: TopicName, message: Bytes) -> Self {
226 LastWill {
227 qos,
228 retain: false,
229 topic_name,
230 message,
231 }
232 }
233}
234
235impl Encodable for LastWill {
236 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
237 write_bytes(writer, self.topic_name.as_bytes())?;
238 write_bytes(writer, self.message.as_ref())?;
239 Ok(())
240 }
241
242 fn encode_len(&self) -> usize {
243 4 + self.topic_name.len() + self.message.len()
244 }
245}
246
/// Return code carried in the second byte of a CONNACK body.
///
/// Each discriminant equals its on-wire byte value.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum ConnectReturnCode {
    /// Wire value 0.
    Accepted = 0,
    /// Wire value 1.
    UnacceptableProtocolVersion = 1,
    /// Wire value 2.
    IdentifierRejected = 2,
    /// Wire value 3.
    ServerUnavailable = 3,
    /// Wire value 4.
    BadUserNameOrPassword = 4,
    /// Wire value 5.
    NotAuthorized = 5,
}
264
265impl ConnectReturnCode {
266 pub fn from_u8(byte: u8) -> Result<ConnectReturnCode, Error> {
267 match byte {
268 0 => Ok(ConnectReturnCode::Accepted),
269 1 => Ok(ConnectReturnCode::UnacceptableProtocolVersion),
270 2 => Ok(ConnectReturnCode::IdentifierRejected),
271 3 => Ok(ConnectReturnCode::ServerUnavailable),
272 4 => Ok(ConnectReturnCode::BadUserNameOrPassword),
273 5 => Ok(ConnectReturnCode::NotAuthorized),
274 n => Err(Error::InvalidConnectReturnCode(n)),
275 }
276 }
277}