/// Character that separates the levels of a topic.
pub const TOPIC_SEPARATOR: char = '/';

/// Multi-level wildcard character.
pub const MULTI_LEVEL_WILDCARD: char = '#';
/// [`MULTI_LEVEL_WILDCARD`] as a string slice.
pub const MULTI_LEVEL_WILDCARD_STR: &str = "#";

/// Single-level wildcard character.
pub const SINGLE_LEVEL_WILDCARD: char = '+';
/// [`SINGLE_LEVEL_WILDCARD`] as a string slice.
pub const SINGLE_LEVEL_WILDCARD_STR: &str = "+";

/// Prefix that introduces a shared subscription.
pub const SHARED_SUBSCRIPTION_PREFIX: &str = "$share/";

/// Maximum topic length in bytes: 65,535, the largest value a `u16` can hold.
pub const MAX_TOPIC_LEN_BYTES: usize = u16::MAX as usize;
12
13pub mod decoder;
14pub mod encoder;
15pub mod topic;
16pub mod types;
17
18#[cfg(feature = "codec")]
19pub mod codec {
20 use crate::{
21 decoder, encoder,
22 types::{DecodeError, EncodeError, Packet, ProtocolVersion},
23 };
24 use bytes::BytesMut;
25 use tokio_util::codec::{Decoder, Encoder};
26
27 pub struct MqttCodec {
28 version: ProtocolVersion,
29 }
30
31 impl Default for MqttCodec {
32 fn default() -> Self {
33 MqttCodec::new()
34 }
35 }
36
37 impl MqttCodec {
38 pub fn new() -> Self {
39 MqttCodec { version: ProtocolVersion::V311 }
40 }
41
42 pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, DecodeError> {
43 let packet = decoder::decode_mqtt(buf, self.version);
45
46 if let Ok(Some(Packet::Connect(packet))) = &packet {
47 self.version = packet.protocol_version;
48 }
49
50 packet
51 }
52
53 pub fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), EncodeError> {
54 encoder::encode_mqtt(&packet, bytes, self.version);
55 Ok(())
56 }
57 }
58
59 impl Decoder for MqttCodec {
60 type Error = DecodeError;
61 type Item = Packet;
62
63 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
64 self.decode(buf)
66 }
67 }
68
69 impl Encoder<Packet> for MqttCodec {
70 type Error = EncodeError;
71
72 fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), Self::Error> {
73 self.encode(packet, bytes)
74 }
75 }
76}
77
78#[cfg(feature = "websocket")]
79pub mod websocket {
80 use bytes::BytesMut;
81 use tokio_util::codec::{Decoder, Encoder};
82
83 pub use websocket_codec as codec;
84
/// Errors reported while decoding a WebSocket upgrade request.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be boxed,
/// logged, or propagated with `?` into generic error types.
#[derive(Debug)]
pub enum WsDecodeError {
    /// The received bytes are not valid UTF-8.
    InvalidString,
    /// The request is not an acceptable upgrade request.
    InvalidUpgradeRequest,
    /// The HTTP version is malformed or not accepted.
    InvalidHttpVersion,
    /// A header line is malformed or carries a value that is not accepted.
    InvalidUpgradeHeaders,
    /// No `Sec-WebSocket-Key` header was found.
    MissingWebSocketKey,
    /// An underlying I/O error.
    Io(std::io::Error),
}

impl std::fmt::Display for WsDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidString => f.write_str("upgrade request is not valid UTF-8"),
            Self::InvalidUpgradeRequest => f.write_str("invalid WebSocket upgrade request"),
            Self::InvalidHttpVersion => f.write_str("invalid or unsupported HTTP version"),
            Self::InvalidUpgradeHeaders => f.write_str("invalid WebSocket upgrade headers"),
            Self::MissingWebSocketKey => f.write_str("missing Sec-WebSocket-Key header"),
            Self::Io(err) => write!(f, "I/O error: {}", err),
        }
    }
}

impl std::error::Error for WsDecodeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::InvalidString
            | Self::InvalidUpgradeRequest
            | Self::InvalidHttpVersion
            | Self::InvalidUpgradeHeaders
            | Self::MissingWebSocketKey => None,
        }
    }
}
94
/// Errors reported while encoding a WebSocket upgrade response.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be boxed,
/// logged, or propagated with `?` into generic error types.
#[derive(Debug)]
pub enum WsEncodeError {
    /// An underlying I/O error.
    Io(std::io::Error),
}

impl std::fmt::Display for WsEncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error: {}", err),
        }
    }
}

impl std::error::Error for WsEncodeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
        }
    }
}
99
100 impl From<std::io::Error> for WsDecodeError {
101 fn from(err: std::io::Error) -> WsDecodeError {
102 WsDecodeError::Io(err)
103 }
104 }
105
106 impl From<std::io::Error> for WsEncodeError {
107 fn from(err: std::io::Error) -> WsEncodeError {
108 WsEncodeError::Io(err)
109 }
110 }
111
112 #[derive(Debug)]
113 pub struct WsUpgraderCodec {}
114
115 impl WsUpgraderCodec {
116 pub fn new() -> Self {
117 Self {}
118 }
119
120 fn validate_request_line(request_line: &str) -> Result<(), WsDecodeError> {
121 let mut request_parts = request_line.split_whitespace();
122 let method = request_parts.next();
123 let uri = request_parts.next();
124 let version = request_parts.next();
125
126 match (method, uri, version) {
127 (Some(method), Some(_uri), Some(version)) => {
128 let is_get = method.eq_ignore_ascii_case("get");
129 let http_version =
130 version.split('/').nth(1).ok_or(WsDecodeError::InvalidHttpVersion)?;
131
132 let mut versions = http_version.split('.');
133 let major_str = versions.next().ok_or(WsDecodeError::InvalidHttpVersion)?;
134 let minor_str = versions.next().ok_or(WsDecodeError::InvalidHttpVersion)?;
135
136 let major: u8 =
137 major_str.parse().map_err(|_| WsDecodeError::InvalidHttpVersion)?;
138 let minor: u8 =
139 minor_str.parse().map_err(|_| WsDecodeError::InvalidHttpVersion)?;
140
141 let version_is_ok = major > 1 || (major == 1 && minor >= 1);
142
143 if is_get && version_is_ok {
144 return Ok(());
145 }
146 },
147 _ => return Err(WsDecodeError::InvalidUpgradeRequest),
148 }
149
150 Ok(())
151 }
152
153 fn validate_headers<'a>(
154 header_lines: impl Iterator<Item = &'a str>,
155 ) -> Result<&'a str, WsDecodeError> {
156 let mut websocket_key: Option<&'a str> = None;
157
158 let mut header_lines = header_lines.peekable();
159
160 while let Some(header_line) = header_lines.next() {
161 let mut split_line = header_line.split(':');
162 let header_name =
163 split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();
164 let header_val =
165 split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();
166
167 match header_name {
168 header if header.eq_ignore_ascii_case("Upgrade") => {
169 if header_val != "websocket" {
170 return Err(WsDecodeError::InvalidUpgradeHeaders);
171 }
172 },
173 header if header.eq_ignore_ascii_case("Connection") => {
174 if header_val != "Upgrade" {
175 return Err(WsDecodeError::InvalidUpgradeHeaders);
176 }
177 },
178 header if header.eq_ignore_ascii_case("Sec-WebSocket-Key") => {
179 websocket_key = Some(header_val);
180 },
181 header if header.eq_ignore_ascii_case("Sec-WebSocket-Version") => {
182 if header_val != "13" {
183 return Err(WsDecodeError::InvalidUpgradeHeaders);
184 }
185 },
186 header if header.eq_ignore_ascii_case("Sec-WebSocket-Protocol") => {
187 let mut versions = header_val.split(',');
188
189 if !versions.any(|proto| proto == "mqtt") {
190 return Err(WsDecodeError::InvalidUpgradeHeaders);
191 }
192 },
193 _ => {},
194 }
195
196 if header_lines.peek() == Some(&"") {
197 break;
198 }
199 }
200
201 websocket_key.ok_or(WsDecodeError::MissingWebSocketKey)
202 }
203 }
204
205 impl Default for WsUpgraderCodec {
206 fn default() -> Self {
207 WsUpgraderCodec {}
208 }
209 }
210
211 impl Decoder for WsUpgraderCodec {
212 type Error = WsDecodeError;
213 type Item = String;
214
215 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
216 match String::from_utf8(buf[..].into()) {
217 Ok(s) => {
218 let mut lines = s.split("\r\n");
219
220 if let Some(request_line) = lines.next() {
221 Self::validate_request_line(request_line)?;
222
223 let websocket_key = Self::validate_headers(lines)?;
224
225 let mut hasher = sha1::Sha1::new();
226 hasher.update(websocket_key.as_bytes());
227 hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
228 let sha1_bytes = hasher.digest().bytes();
229 let sha1_str = base64::encode(&sha1_bytes);
230
231 let _rest = buf.split_to(s.len());
232
233 Ok(Some(sha1_str))
234 } else {
235 Ok(None)
236 }
237 },
238 Err(_e) => Err(WsDecodeError::InvalidString),
239 }
240 }
241 }
242
243 impl Encoder<String> for WsUpgraderCodec {
244 type Error = WsEncodeError;
245
246 fn encode(
247 &mut self,
248 websocket_key: String,
249 bytes: &mut BytesMut,
250 ) -> Result<(), Self::Error> {
251 let response = format!(
252 "HTTP/1.1 101 Switching Protocols\r\n\
253 Upgrade: websocket\r\n\
254 Connection: Upgrade\r\n\
255 Sec-WebSocket-Protocol: mqtt\r\n\
256 Sec-WebSocket-Accept: {}\r\n\r\n",
257 websocket_key
258 );
259
260 bytes.extend_from_slice(response.as_bytes());
261 Ok(())
262 }
263 }
264}