// mqtt_v5_fork/lib.rs

//! MQTT packet types, encoder/decoder, topic helpers and (feature-gated) transport codecs.
#![forbid(unsafe_code)]

/// Separator between levels of an MQTT topic name or topic filter.
pub const TOPIC_SEPARATOR: char = '/';

/// Multi-level wildcard (`#`) usable in topic filters.
pub const MULTI_LEVEL_WILDCARD: char = '#';
/// [`MULTI_LEVEL_WILDCARD`] as a string slice, for comparisons against topic levels.
pub const MULTI_LEVEL_WILDCARD_STR: &str = "#";

/// Single-level wildcard (`+`) usable in topic filters.
pub const SINGLE_LEVEL_WILDCARD: char = '+';
/// [`SINGLE_LEVEL_WILDCARD`] as a string slice, for comparisons against topic levels.
pub const SINGLE_LEVEL_WILDCARD_STR: &str = "+";

/// Prefix identifying a shared subscription filter (`$share/{ShareName}/{filter}`).
pub const SHARED_SUBSCRIPTION_PREFIX: &str = "$share/";

/// Maximum topic length in bytes.
///
/// This equals `u16::MAX`, matching the 16-bit length prefix MQTT uses for UTF-8 strings.
pub const MAX_TOPIC_LEN_BYTES: usize = 65_535;

14pub mod decoder;
15pub mod encoder;
16pub mod topic;
17pub mod types;
18
/// `tokio_util` codec that frames a byte stream into MQTT [`Packet`]s and back.
///
/// Only compiled with the `codec` feature.
#[cfg(feature = "codec")]
pub mod codec {
    use crate::{
        decoder, encoder,
        types::{DecodeError, EncodeError, Packet, ProtocolVersion},
    };
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    /// Stateful MQTT codec.
    ///
    /// The codec tracks the protocol version of the most recent CONNECT packet that passes
    /// through it and uses that version for every later packet. A CONNECT can arrive either
    /// by decoding (broker side) or by encoding (client side). Until a CONNECT is seen, the
    /// version given at construction is used; [`MqttCodec::new`] starts with MQTT v3.1.1.
    pub struct MqttCodec {
        version: ProtocolVersion,
    }

    impl Default for MqttCodec {
        fn default() -> Self {
            MqttCodec::new()
        }
    }

    impl MqttCodec {
        /// Creates a codec that assumes MQTT v3.1.1 until a CONNECT packet says otherwise.
        pub fn new() -> Self {
            MqttCodec::with_version(ProtocolVersion::V311)
        }

        /// Creates a codec that starts out using `version`.
        ///
        /// Useful when the protocol version is already known before a CONNECT is seen.
        pub fn with_version(version: ProtocolVersion) -> Self {
            MqttCodec { version }
        }

        /// Returns the protocol version currently used to encode and decode packets.
        pub fn version(&self) -> ProtocolVersion {
            self.version
        }

        /// Decodes the next packet from `buf`. Behaves exactly like [`Decoder::decode`].
        ///
        /// # Errors
        /// Returns the [`DecodeError`] produced by the packet decoder.
        pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, DecodeError> {
            // Explicit trait path: this inherent method shares its name with the trait method.
            <Self as Decoder>::decode(self, buf)
        }

        /// Encodes `packet` into `bytes`. Behaves exactly like [`Encoder::encode`].
        pub fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), EncodeError> {
            <Self as Encoder<Packet>>::encode(self, packet, bytes)
        }
    }

    impl Decoder for MqttCodec {
        type Error = DecodeError;
        type Item = Packet;

        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            // TODO - Ideally we should keep a state machine to store the data we've read so far.
            let packet = decoder::decode_mqtt(buf, self.version)?;

            // The client's CONNECT fixes the protocol version for the rest of the session.
            if let Some(Packet::Connect(connect)) = &packet {
                self.version = connect.protocol_version;
            }

            Ok(packet)
        }
    }

    impl Encoder<Packet> for MqttCodec {
        type Error = EncodeError;

        fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            // On the client side the CONNECT is *sent*. Adopt its version before encoding so
            // that the CONNECT itself and every later packet (including the decoded CONNACK)
            // use the negotiated version instead of the default.
            if let Packet::Connect(connect) = &packet {
                self.version = connect.protocol_version;
            }

            // NOTE(review): the return value of `encode_mqtt` is not inspected, so this never
            // returns `Err` — confirm that is intended.
            encoder::encode_mqtt(&packet, bytes, self.version);
            Ok(())
        }
    }
}

/// WebSocket transport support: an HTTP upgrade (opening handshake) codec, plus a re-export
/// of the `websocket_codec` crate for the framed connection that follows.
///
/// Only compiled with the `websocket` feature.
#[cfg(feature = "websocket")]
pub mod websocket {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    pub use websocket_codec as codec;

    /// Errors produced while parsing a WebSocket upgrade request.
    #[derive(Debug)]
    pub enum WsDecodeError {
        /// The request head is not valid UTF-8.
        InvalidString,
        /// The request line is malformed or not a `GET`, or the request head is too large.
        InvalidUpgradeRequest,
        /// The HTTP version is malformed or older than HTTP/1.1.
        InvalidHttpVersion,
        /// A header line is malformed, or an upgrade-related header has an unacceptable value.
        InvalidUpgradeHeaders,
        /// The `Sec-WebSocket-Key` header is absent or empty.
        MissingWebSocketKey,
        /// An I/O error from the underlying transport.
        Io(std::io::Error),
    }

    /// Errors produced while writing the upgrade response.
    #[derive(Debug)]
    pub enum WsEncodeError {
        /// An I/O error from the underlying transport.
        Io(std::io::Error),
    }

    impl From<std::io::Error> for WsDecodeError {
        fn from(err: std::io::Error) -> WsDecodeError {
            WsDecodeError::Io(err)
        }
    }

    impl From<std::io::Error> for WsEncodeError {
        fn from(err: std::io::Error) -> WsEncodeError {
            WsEncodeError::Io(err)
        }
    }

    /// Codec for the server side of the WebSocket opening handshake (RFC 6455 §4.2).
    ///
    /// Decoding a valid upgrade request yields the `Sec-WebSocket-Accept` value. Encoding
    /// that value writes the `101 Switching Protocols` response, which always advertises the
    /// `mqtt` sub-protocol.
    #[derive(Debug, Default)]
    pub struct WsUpgraderCodec {}

    impl WsUpgraderCodec {
        /// Largest request head buffered while waiting for the terminating blank line.
        /// Without this bound, a peer that never finishes its headers grows the buffer forever.
        const MAX_REQUEST_HEAD_LEN: usize = 64 * 1024;

        /// Blank line that terminates an HTTP request head.
        const HEAD_TERMINATOR: &'static [u8] = b"\r\n\r\n";

        /// Fixed GUID appended to the client key when computing the accept value.
        const WEBSOCKET_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

        pub fn new() -> Self {
            Self {}
        }

        /// Checks that the request line is `GET <uri> HTTP/<major>.<minor>` with version >= 1.1.
        fn validate_request_line(request_line: &str) -> Result<(), WsDecodeError> {
            let mut request_parts = request_line.split_whitespace();
            let (method, version) =
                match (request_parts.next(), request_parts.next(), request_parts.next()) {
                    (Some(method), Some(_uri), Some(version)) => (method, version),
                    _ => return Err(WsDecodeError::InvalidUpgradeRequest),
                };

            // RFC 6455 §4.1: the handshake MUST be a GET.
            if !method.eq_ignore_ascii_case("get") {
                return Err(WsDecodeError::InvalidUpgradeRequest);
            }

            // "HTTP/1.1" -> "1.1" -> (1, 1)
            let http_version =
                version.split('/').nth(1).ok_or(WsDecodeError::InvalidHttpVersion)?;
            let mut numbers = http_version.split('.');
            let major: u8 = numbers
                .next()
                .and_then(|part| part.parse().ok())
                .ok_or(WsDecodeError::InvalidHttpVersion)?;
            let minor: u8 = numbers
                .next()
                .and_then(|part| part.parse().ok())
                .ok_or(WsDecodeError::InvalidHttpVersion)?;

            // RFC 6455 §4.1: the HTTP version MUST be at least 1.1.
            if major > 1 || (major == 1 && minor >= 1) {
                Ok(())
            } else {
                Err(WsDecodeError::InvalidHttpVersion)
            }
        }

        /// Returns true if the comma-separated header value `list` contains `token`,
        /// compared ASCII case-insensitively.
        fn has_token(list: &str, token: &str) -> bool {
            list.split(',').any(|item| item.trim().eq_ignore_ascii_case(token))
        }

        /// Validates the upgrade-related headers and returns the `Sec-WebSocket-Key` value.
        ///
        /// Headers are only checked when present. Unknown headers are ignored. Processing
        /// stops at the first empty line.
        fn validate_headers<'a>(
            header_lines: impl Iterator<Item = &'a str>,
        ) -> Result<&'a str, WsDecodeError> {
            let mut websocket_key: Option<&'a str> = None;

            for header_line in header_lines {
                if header_line.is_empty() {
                    break;
                }

                // Split on the first ':' only, so values such as `host:port` stay intact.
                let mut split_line = header_line.splitn(2, ':');
                let header_name =
                    split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();
                let header_val =
                    split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();

                if header_name.eq_ignore_ascii_case("Upgrade") {
                    // RFC 6455 §4.2.1: "websocket", ASCII case-insensitive.
                    if !Self::has_token(header_val, "websocket") {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    }
                } else if header_name.eq_ignore_ascii_case("Connection") {
                    // RFC 6455 §4.2.1: must *include* the token "Upgrade" (for example,
                    // browsers may send "keep-alive, Upgrade").
                    if !Self::has_token(header_val, "Upgrade") {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    }
                } else if header_name.eq_ignore_ascii_case("Sec-WebSocket-Key") {
                    websocket_key = Some(header_val);
                } else if header_name.eq_ignore_ascii_case("Sec-WebSocket-Version") {
                    if header_val != "13" {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    }
                } else if header_name.eq_ignore_ascii_case("Sec-WebSocket-Protocol") {
                    // Sub-protocols are a comma-separated list; entries may carry whitespace.
                    if !header_val.split(',').any(|proto| proto.trim() == "mqtt") {
                        return Err(WsDecodeError::InvalidUpgradeHeaders);
                    }
                }
            }

            websocket_key
                .filter(|key| !key.is_empty())
                .ok_or(WsDecodeError::MissingWebSocketKey)
        }

        /// Computes the `Sec-WebSocket-Accept` value: base64(SHA-1(key ++ GUID)).
        fn accept_key(websocket_key: &str) -> String {
            let mut hasher = sha1::Sha1::new();
            hasher.update(websocket_key.as_bytes());
            hasher.update(Self::WEBSOCKET_GUID);
            base64::encode(&hasher.digest().bytes())
        }
    }

    impl Decoder for WsUpgraderCodec {
        type Error = WsDecodeError;
        type Item = String;

        /// Parses one complete upgrade request from `buf`.
        ///
        /// Returns `Ok(None)` until the terminating blank line has arrived. On success it
        /// consumes exactly the request head and leaves any following bytes in `buf`.
        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            let terminator = Self::HEAD_TERMINATOR;
            let head_len = match buf.windows(terminator.len()).position(|w| w == terminator) {
                Some(len) => len,
                None if buf.len() > Self::MAX_REQUEST_HEAD_LEN => {
                    return Err(WsDecodeError::InvalidUpgradeRequest)
                },
                None => return Ok(None),
            };

            // Scope the borrow of `buf` so it can be advanced afterwards.
            let accept_key = {
                let head = std::str::from_utf8(&buf[..head_len])
                    .map_err(|_| WsDecodeError::InvalidString)?;
                let mut lines = head.split("\r\n");

                // `split` always yields at least one item, so the fallback is never used.
                let request_line = lines.next().unwrap_or("");
                Self::validate_request_line(request_line)?;

                let websocket_key = Self::validate_headers(lines)?;
                Self::accept_key(websocket_key)
            };

            let _request_head = buf.split_to(head_len + terminator.len());

            Ok(Some(accept_key))
        }
    }

    impl Encoder<String> for WsUpgraderCodec {
        type Error = WsEncodeError;

        /// Writes the `101 Switching Protocols` response for the given accept value.
        fn encode(&mut self, accept_key: String, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Protocol: mqtt\r\n\
                 Sec-WebSocket-Accept: {}\r\n\r\n",
                accept_key
            );

            bytes.extend_from_slice(response.as_bytes());
            Ok(())
        }
    }
}