1use core::convert::TryFrom;
2
3use alloc::string::String;
4use alloc::sync::Arc;
5
6use bytes::Bytes;
7
8use crate::{
9 from_read_exact_error, read_bytes, read_string, read_u16, read_u8, write_bytes, write_string,
10 write_u16, write_u8, AsyncRead, Encodable, Error, Protocol, QoS, SyncWrite, TopicName,
11};
12
/// Contents of a CONNECT packet (protocol through password).
///
/// Decoded by `Connect::decode_async` / `Connect::decode_with_protocol` and
/// serialized by the `Encodable` impl in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Connect {
    /// Protocol name/level; the decoder rejects values whose `u8` repr exceeds 4.
    pub protocol: Protocol,
    /// Clean-session request, carried in bit 1 of the connect-flags byte.
    pub clean_session: bool,
    /// Keep-alive value written as a `u16` (the MQTT 3.1.1 spec defines it in seconds).
    pub keep_alive: u16,
    /// Client identifier, encoded as a length-prefixed string.
    pub client_id: Arc<String>,
    /// Optional will; `Some` sets the will flag and the will QoS/retain bits.
    pub last_will: Option<LastWill>,
    /// Optional username; `Some` sets bit 7 of the connect-flags byte.
    pub username: Option<Arc<String>>,
    /// Optional password bytes; `Some` sets bit 6 of the connect-flags byte.
    pub password: Option<Bytes>,
}
24
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Connect {
    /// Builds a `Connect` from fuzzer-supplied bytes.
    ///
    /// Each field is drawn in declaration order. The password is generated as
    /// an optional `Vec<u8>` and then converted into `Bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let protocol = u.arbitrary()?;
        let clean_session = u.arbitrary()?;
        let keep_alive = u.arbitrary()?;
        let client_id = u.arbitrary()?;
        let last_will = u.arbitrary()?;
        let username = u.arbitrary()?;
        let raw_password = Option::<Vec<u8>>::arbitrary(u)?;
        Ok(Connect {
            protocol,
            clean_session,
            keep_alive,
            client_id,
            last_will,
            username,
            password: raw_password.map(Bytes::from),
        })
    }
}
39
40impl Connect {
41 pub fn new(client_id: Arc<String>, keep_alive: u16) -> Self {
42 Connect {
43 protocol: Protocol::V311,
44 clean_session: true,
45 keep_alive,
46 client_id,
47 last_will: None,
48 username: None,
49 password: None,
50 }
51 }
52
53 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
54 let protocol = Protocol::decode_async(reader).await?;
55 Self::decode_with_protocol(reader, protocol).await
56 }
57
58 #[inline]
59 pub async fn decode_with_protocol<T: AsyncRead + Unpin>(
60 reader: &mut T,
61 protocol: Protocol,
62 ) -> Result<Self, Error> {
63 if protocol as u8 > 4 {
64 return Err(Error::UnexpectedProtocol(protocol));
65 }
66 let connect_flags: u8 = read_u8(reader).await?;
67 if connect_flags & 1 != 0 {
68 return Err(Error::InvalidConnectFlags(connect_flags));
69 }
70 let keep_alive = read_u16(reader).await?;
71 let client_id = Arc::new(read_string(reader).await?);
72 let last_will = if connect_flags & 0b100 != 0 {
73 let topic_name = read_string(reader).await?;
74 let message = read_bytes(reader).await?;
75 let qos = QoS::from_u8((connect_flags & 0b11000) >> 3)?;
76 let retain = (connect_flags & 0b00100000) != 0;
77 Some(LastWill {
78 topic_name: TopicName::try_from(topic_name)?,
79 message: Bytes::from(message),
80 qos,
81 retain,
82 })
83 } else if connect_flags & 0b11000 != 0 {
84 return Err(Error::InvalidConnectFlags(connect_flags));
85 } else {
86 None
87 };
88 let username = if connect_flags & 0b10000000 != 0 {
89 Some(Arc::new(read_string(reader).await?))
90 } else {
91 None
92 };
93 let password = if connect_flags & 0b01000000 != 0 {
94 Some(Bytes::from(read_bytes(reader).await?))
95 } else {
96 None
97 };
98 let clean_session = (connect_flags & 0b10) != 0;
99 Ok(Connect {
100 protocol,
101 keep_alive,
102 client_id,
103 username,
104 password,
105 last_will,
106 clean_session,
107 })
108 }
109}
110
111impl Encodable for Connect {
112 fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
113 let mut connect_flags: u8 = 0b00000000;
114 if self.clean_session {
115 connect_flags |= 0b10;
116 }
117 if self.username.is_some() {
118 connect_flags |= 0b10000000;
119 }
120 if self.password.is_some() {
121 connect_flags |= 0b01000000;
122 }
123 if let Some(last_will) = self.last_will.as_ref() {
124 connect_flags |= 0b00000100;
125 connect_flags |= (last_will.qos as u8) << 3;
126 if last_will.retain {
127 connect_flags |= 0b00100000;
128 }
129 }
130
131 self.protocol.encode(writer)?;
132 write_u8(writer, connect_flags)?;
133 write_u16(writer, self.keep_alive)?;
134 write_string(writer, &self.client_id)?;
135 if let Some(last_will) = self.last_will.as_ref() {
136 last_will.encode(writer)?;
137 }
138 if let Some(username) = self.username.as_ref() {
139 write_string(writer, username)?;
140 }
141 if let Some(password) = self.password.as_ref() {
142 write_bytes(writer, password.as_ref())?;
143 }
144 Ok(())
145 }
146
147 fn encode_len(&self) -> usize {
148 let mut length = self.protocol.encode_len();
149 length += 1 + 2;
151 length += 2 + self.client_id.len();
153 if let Some(last_will) = self.last_will.as_ref() {
154 length += last_will.encode_len();
155 }
156 if let Some(username) = self.username.as_ref() {
157 length += 2 + username.len();
158 }
159 if let Some(password) = self.password.as_ref() {
160 length += 2 + password.len();
161 }
162 length
163 }
164}
165
/// CONNACK packet contents.
///
/// `Connack::decode_async` reads these from two bytes:
/// - an acknowledge-flags byte (must be 0 or 1), which gives `session_present`;
/// - a return-code byte, which gives `code`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Connack {
    /// Session-present flag (acknowledge-flags byte equal to 1).
    pub session_present: bool,
    /// Connect return code.
    pub code: ConnectReturnCode,
}
173
174impl Connack {
175 pub fn new(session_present: bool, code: ConnectReturnCode) -> Self {
176 Connack {
177 session_present,
178 code,
179 }
180 }
181
182 pub async fn decode_async<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Self, Error> {
183 let mut payload = [0u8; 2];
184 reader
185 .read_exact(&mut payload)
186 .await
187 .map_err(from_read_exact_error)?;
188 let session_present = match payload[0] {
189 0 => false,
190 1 => true,
191 _ => return Err(Error::InvalidConnackFlags(payload[0])),
192 };
193 let code = ConnectReturnCode::from_u8(payload[1])?;
194 Ok(Connack {
195 session_present,
196 code,
197 })
198 }
199}
200
/// Will message carried inside a CONNECT packet.
///
/// Its QoS and retain settings are encoded in the CONNECT flags byte. Its
/// topic and message are written by the `Encodable` impl for `LastWill`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LastWill {
    /// Will QoS (connect-flags bits 3-4).
    pub qos: QoS,
    /// Will retain flag (connect-flags bit 5).
    pub retain: bool,
    /// Will topic, written as a length-prefixed string.
    pub topic_name: TopicName,
    /// Will payload, written as length-prefixed bytes.
    pub message: Bytes,
}
214
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for LastWill {
    /// Builds a `LastWill` from fuzzer-supplied bytes.
    ///
    /// Fields are drawn in declaration order. The message is generated as a
    /// `Vec<u8>` and then converted into `Bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let qos = u.arbitrary()?;
        let retain = u.arbitrary()?;
        let topic_name = u.arbitrary()?;
        let payload = Vec::<u8>::arbitrary(u)?;
        Ok(LastWill {
            qos,
            retain,
            topic_name,
            message: Bytes::from(payload),
        })
    }
}
226
227impl LastWill {
228 pub fn new(qos: QoS, topic_name: TopicName, message: Bytes) -> Self {
229 LastWill {
230 qos,
231 retain: false,
232 topic_name,
233 message,
234 }
235 }
236}
237
238impl Encodable for LastWill {
239 fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
240 write_string(writer, &self.topic_name)?;
241 write_bytes(writer, self.message.as_ref())?;
242 Ok(())
243 }
244
245 fn encode_len(&self) -> usize {
246 4 + self.topic_name.len() + self.message.len()
247 }
248}
249
/// Return code carried in a CONNACK packet.
///
/// The explicit discriminants are the byte values accepted by
/// `ConnectReturnCode::from_u8`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum ConnectReturnCode {
    /// Wire value 0.
    Accepted = 0,
    /// Wire value 1.
    UnacceptableProtocolVersion = 1,
    /// Wire value 2.
    IdentifierRejected = 2,
    /// Wire value 3.
    ServerUnavailable = 3,
    /// Wire value 4.
    BadUserNameOrPassword = 4,
    /// Wire value 5.
    NotAuthorized = 5,
}
267
268impl ConnectReturnCode {
269 pub fn from_u8(byte: u8) -> Result<ConnectReturnCode, Error> {
270 match byte {
271 0 => Ok(ConnectReturnCode::Accepted),
272 1 => Ok(ConnectReturnCode::UnacceptableProtocolVersion),
273 2 => Ok(ConnectReturnCode::IdentifierRejected),
274 3 => Ok(ConnectReturnCode::ServerUnavailable),
275 4 => Ok(ConnectReturnCode::BadUserNameOrPassword),
276 5 => Ok(ConnectReturnCode::NotAuthorized),
277 n => Err(Error::InvalidConnectReturnCode(n)),
278 }
279 }
280}