#![forbid(unsafe_code)]

/// Character separating the levels of a topic name or topic filter (`a/b/c`).
pub const TOPIC_SEPARATOR: char = '/';

/// Multi-level wildcard usable in topic filters.
pub const MULTI_LEVEL_WILDCARD: char = '#';
/// The multi-level wildcard as a one-character string slice.
pub const MULTI_LEVEL_WILDCARD_STR: &str = "#";

/// Single-level wildcard usable in topic filters.
pub const SINGLE_LEVEL_WILDCARD: char = '+';
/// The single-level wildcard as a one-character string slice.
pub const SINGLE_LEVEL_WILDCARD_STR: &str = "+";

/// Prefix that marks a topic filter as a shared subscription
/// (`$share/<ShareName>/<filter>` in MQTT 5).
pub const SHARED_SUBSCRIPTION_PREFIX: &str = "$share/";

/// Maximum topic length in bytes (65,535): the largest value the two-byte
/// length prefix of an MQTT UTF-8 string can encode.
pub const MAX_TOPIC_LEN_BYTES: usize = u16::MAX as usize;
13
14pub mod decoder;
15pub mod encoder;
16pub mod topic;
17pub mod types;
18
/// `tokio_util` codec that turns a byte stream into MQTT [`Packet`]s and back.
///
/// Only compiled with the `codec` feature.
#[cfg(feature = "codec")]
pub mod codec {
    use crate::{
        decoder, encoder,
        types::{DecodeError, EncodeError, Packet, ProtocolVersion},
    };
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    /// Stateful MQTT codec.
    ///
    /// A fresh codec uses [`ProtocolVersion::V311`]. After it decodes a CONNECT
    /// packet it adopts that packet's `protocol_version` for all later decode
    /// and encode calls.
    pub struct MqttCodec {
        /// Version passed to `decoder::decode_mqtt` and `encoder::encode_mqtt`.
        version: ProtocolVersion,
    }

    impl Default for MqttCodec {
        /// Equivalent to [`MqttCodec::new`].
        fn default() -> Self {
            Self::new()
        }
    }

    impl MqttCodec {
        /// Creates a codec that assumes MQTT 3.1.1 until a CONNECT packet says otherwise.
        pub fn new() -> Self {
            Self { version: ProtocolVersion::V311 }
        }

        /// Attempts to decode one packet from `buf`.
        ///
        /// The result of `decoder::decode_mqtt` is returned unchanged. As a side
        /// effect, a successfully decoded CONNECT packet switches this codec to the
        /// protocol version that packet carries.
        ///
        /// # Errors
        ///
        /// Propagates the [`DecodeError`] produced by `decoder::decode_mqtt`.
        pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, DecodeError> {
            let decoded = decoder::decode_mqtt(buf, self.version);

            // Later decode/encode calls use whatever version the client announced.
            if let Ok(Some(Packet::Connect(connect))) = decoded.as_ref() {
                self.version = connect.protocol_version;
            }

            decoded
        }

        /// Serializes `packet` into `bytes` with `encoder::encode_mqtt`, using the
        /// codec's current protocol version.
        ///
        /// # Errors
        ///
        /// Never fails; [`EncodeError`] only appears to satisfy the signature.
        pub fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), EncodeError> {
            let version = self.version;
            encoder::encode_mqtt(&packet, bytes, version);
            Ok(())
        }
    }

    impl Decoder for MqttCodec {
        type Error = DecodeError;
        type Item = Packet;

        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            // The type path resolves to the inherent method, not this trait method.
            MqttCodec::decode(self, buf)
        }
    }

    impl Encoder<Packet> for MqttCodec {
        type Error = EncodeError;

        fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            // The type path resolves to the inherent method, not this trait method.
            MqttCodec::encode(self, packet, bytes)
        }
    }
}
78
/// HTTP/1.1 upgrade handshake for serving MQTT over WebSockets.
///
/// Only compiled with the `websocket` feature.
#[cfg(feature = "websocket")]
pub mod websocket {
    use bytes::BytesMut;
    use std::fmt;
    use tokio_util::codec::{Decoder, Encoder};

    pub use websocket_codec as codec;

    /// GUID appended to the client's key before hashing it into the
    /// `Sec-WebSocket-Accept` value (RFC 6455, section 4.2.2).
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// Blank line that terminates the HTTP request head.
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    /// Largest request head (in bytes) buffered while waiting for the
    /// terminating blank line. Beyond this the request is rejected, so a peer
    /// cannot make the decoder buffer without bound.
    const MAX_UPGRADE_REQUEST_LEN: usize = 64 * 1024;

    /// Errors produced while decoding a WebSocket upgrade request.
    #[derive(Debug)]
    pub enum WsDecodeError {
        /// The request head is not valid UTF-8.
        InvalidString,
        /// The request line is malformed, the method is not `GET`, or the
        /// request head grew past the size limit without being terminated.
        InvalidUpgradeRequest,
        /// The HTTP version token is malformed or older than HTTP/1.1.
        InvalidHttpVersion,
        /// A header line has no `:`, or a handshake header carries an
        /// unacceptable value.
        InvalidUpgradeHeaders,
        /// No non-empty `Sec-WebSocket-Key` header was sent.
        MissingWebSocketKey,
        /// An I/O error converted via `From<std::io::Error>`.
        Io(std::io::Error),
    }

    /// Errors produced while encoding the upgrade response.
    #[derive(Debug)]
    pub enum WsEncodeError {
        /// An I/O error converted via `From<std::io::Error>`.
        Io(std::io::Error),
    }

    impl fmt::Display for WsDecodeError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                WsDecodeError::InvalidString => f.write_str("upgrade request is not valid UTF-8"),
                WsDecodeError::InvalidUpgradeRequest => {
                    f.write_str("invalid WebSocket upgrade request")
                },
                WsDecodeError::InvalidHttpVersion => {
                    f.write_str("invalid or unsupported HTTP version in upgrade request")
                },
                WsDecodeError::InvalidUpgradeHeaders => {
                    f.write_str("invalid WebSocket upgrade headers")
                },
                WsDecodeError::MissingWebSocketKey => {
                    f.write_str("missing Sec-WebSocket-Key header")
                },
                WsDecodeError::Io(err) => write!(f, "I/O error: {}", err),
            }
        }
    }

    impl std::error::Error for WsDecodeError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self {
                WsDecodeError::Io(err) => Some(err),
                _ => None,
            }
        }
    }

    impl fmt::Display for WsEncodeError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                WsEncodeError::Io(err) => write!(f, "I/O error: {}", err),
            }
        }
    }

    impl std::error::Error for WsEncodeError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self {
                WsEncodeError::Io(err) => Some(err),
            }
        }
    }

    impl From<std::io::Error> for WsDecodeError {
        fn from(err: std::io::Error) -> WsDecodeError {
            WsDecodeError::Io(err)
        }
    }

    impl From<std::io::Error> for WsEncodeError {
        fn from(err: std::io::Error) -> WsEncodeError {
            WsEncodeError::Io(err)
        }
    }

    /// Codec for the server side of the WebSocket opening handshake.
    ///
    /// Decoding yields the `Sec-WebSocket-Accept` value for a valid upgrade
    /// request. Encoding that value writes the `101 Switching Protocols` response.
    #[derive(Debug, Default)]
    pub struct WsUpgraderCodec {}

    impl WsUpgraderCodec {
        /// Creates a new handshake codec.
        pub fn new() -> Self {
            Self {}
        }

        /// Validates a request line of the form `<method> <uri> HTTP/<major>.<minor>`.
        ///
        /// The method must be `GET` (ASCII case-insensitive) and the version at
        /// least 1.1, as RFC 6455 requires for the opening handshake.
        fn validate_request_line(request_line: &str) -> Result<(), WsDecodeError> {
            let mut parts = request_line.split_whitespace();

            let (method, version) = match (parts.next(), parts.next(), parts.next()) {
                (Some(method), Some(_uri), Some(version)) => (method, version),
                _ => return Err(WsDecodeError::InvalidUpgradeRequest),
            };

            if !method.eq_ignore_ascii_case("GET") {
                return Err(WsDecodeError::InvalidUpgradeRequest);
            }

            if Self::parse_http_version(version)? >= (1, 1) {
                Ok(())
            } else {
                Err(WsDecodeError::InvalidHttpVersion)
            }
        }

        /// Parses an `HTTP/<major>.<minor>` token into `(major, minor)`.
        fn parse_http_version(version: &str) -> Result<(u8, u8), WsDecodeError> {
            let mut halves = version.splitn(2, '/');
            let protocol = halves.next().unwrap_or("");
            let numbers = halves.next().ok_or(WsDecodeError::InvalidHttpVersion)?;

            if !protocol.eq_ignore_ascii_case("HTTP") {
                return Err(WsDecodeError::InvalidHttpVersion);
            }

            let mut numbers = numbers.split('.').map(|part| part.parse::<u8>());
            match (numbers.next(), numbers.next()) {
                (Some(Ok(major)), Some(Ok(minor))) => Ok((major, minor)),
                _ => Err(WsDecodeError::InvalidHttpVersion),
            }
        }

        /// Returns whether the comma-separated header `value` contains `token`,
        /// compared ASCII case-insensitively after trimming whitespace.
        fn has_token(value: &str, token: &str) -> bool {
            value.split(',').any(|item| item.trim().eq_ignore_ascii_case(token))
        }

        /// Validates the handshake headers and returns the client's `Sec-WebSocket-Key`.
        ///
        /// Lines are consumed up to the first empty line, which ends the header section.
        /// The checks are:
        /// - `Upgrade` must list `websocket`.
        /// - `Connection` must list `Upgrade`.
        /// - `Sec-WebSocket-Version` must be `13`.
        /// - `Sec-WebSocket-Protocol`, if present, must offer `mqtt`.
        fn validate_headers<'a>(
            header_lines: impl Iterator<Item = &'a str>,
        ) -> Result<&'a str, WsDecodeError> {
            let mut websocket_key: Option<&'a str> = None;

            for header_line in header_lines.take_while(|line| !line.is_empty()) {
                // Split on the first ':' only; values such as `Host: example:1883`
                // contain further colons.
                let mut split_line = header_line.splitn(2, ':');
                let name = split_line.next().unwrap_or("").trim();
                let value =
                    split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();

                // Header names are case-insensitive.
                match name.to_ascii_lowercase().as_str() {
                    "upgrade" if !Self::has_token(value, "websocket") => {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    },
                    // Browsers may send e.g. `Connection: keep-alive, Upgrade`.
                    "connection" if !Self::has_token(value, "Upgrade") => {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    },
                    "sec-websocket-key" => websocket_key = Some(value),
                    "sec-websocket-version" if value != "13" => {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    },
                    "sec-websocket-protocol"
                        if !value.split(',').any(|proto| proto.trim() == "mqtt") =>
                    {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    },
                    _ => {},
                }
            }

            websocket_key.filter(|key| !key.is_empty()).ok_or(WsDecodeError::MissingWebSocketKey)
        }

        /// Computes the `Sec-WebSocket-Accept` value: base64 of the SHA-1 of the
        /// client key followed by the RFC 6455 GUID.
        fn accept_key(websocket_key: &str) -> String {
            let mut hasher = sha1::Sha1::new();
            hasher.update(websocket_key.as_bytes());
            hasher.update(WEBSOCKET_GUID);
            let sha1_bytes = hasher.digest().bytes();
            base64::encode(&sha1_bytes)
        }
    }

    impl Decoder for WsUpgraderCodec {
        type Error = WsDecodeError;
        type Item = String;

        /// Decodes one HTTP upgrade request and yields the `Sec-WebSocket-Accept`
        /// value to send back.
        ///
        /// Returns `Ok(None)` until the whole request head (ending in a blank
        /// line) is buffered. On success only the request head is consumed; any
        /// bytes after it stay in `buf` for whatever codec takes over next.
        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            let terminator_pos = buf
                .windows(HEADER_TERMINATOR.len())
                .position(|window| window == HEADER_TERMINATOR);

            let head_len = match terminator_pos {
                Some(start) => start + HEADER_TERMINATOR.len(),
                None if buf.len() > MAX_UPGRADE_REQUEST_LEN => {
                    return Err(WsDecodeError::InvalidUpgradeRequest);
                },
                // The request arrived partially; wait for more bytes.
                None => return Ok(None),
            };

            let accept_key = {
                let head = std::str::from_utf8(&buf[..head_len])
                    .map_err(|_| WsDecodeError::InvalidString)?;
                let mut lines = head.split("\r\n");

                // `split` always yields at least one item: the request line.
                let request_line = lines.next().unwrap_or("");
                Self::validate_request_line(request_line)?;

                let websocket_key = Self::validate_headers(lines)?;
                Self::accept_key(websocket_key)
            };

            // Consume exactly the request head.
            let _request_head = buf.split_to(head_len);

            Ok(Some(accept_key))
        }
    }

    impl Encoder<String> for WsUpgraderCodec {
        type Error = WsEncodeError;

        /// Writes the `101 Switching Protocols` response into `bytes`.
        ///
        /// `accept_key` is the `Sec-WebSocket-Accept` value yielded by the
        /// decoder. The response always advertises the `mqtt` subprotocol.
        fn encode(&mut self, accept_key: String, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Protocol: mqtt\r\n\
                 Sec-WebSocket-Accept: {}\r\n\r\n",
                accept_key
            );

            bytes.extend_from_slice(response.as_bytes());
            Ok(())
        }
    }
}