// mqtt_v5/lib.rs

/// Character separating the levels of an MQTT topic, e.g. `a/b/c`.
pub const TOPIC_SEPARATOR: char = '/';

/// MQTT multi-level wildcard (`#`) used in topic filters.
pub const MULTI_LEVEL_WILDCARD: char = '#';
/// [`MULTI_LEVEL_WILDCARD`] as a string slice.
pub const MULTI_LEVEL_WILDCARD_STR: &str = "#";

/// MQTT single-level wildcard (`+`) used in topic filters.
pub const SINGLE_LEVEL_WILDCARD: char = '+';
/// [`SINGLE_LEVEL_WILDCARD`] as a string slice.
pub const SINGLE_LEVEL_WILDCARD_STR: &str = "+";

/// Prefix identifying a shared-subscription topic filter.
pub const SHARED_SUBSCRIPTION_PREFIX: &str = "$share/";

/// Longest topic accepted, in bytes: MQTT strings carry a two-byte length prefix,
/// so this is `u16::MAX` (65 535).
pub const MAX_TOPIC_LEN_BYTES: usize = u16::MAX as usize;
12
13pub mod decoder;
14pub mod encoder;
15pub mod topic;
16pub mod types;
17
#[cfg(feature = "codec")]
pub mod codec {
    //! `tokio_util` framing for MQTT packets.

    use crate::{
        decoder, encoder,
        types::{DecodeError, EncodeError, Packet, ProtocolVersion},
    };
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    /// Stateful MQTT packet codec.
    ///
    /// Starts out with MQTT v3.1.1 and switches to the protocol version carried by the
    /// first decoded CONNECT packet; that version is then passed to every later
    /// decode and encode.
    pub struct MqttCodec {
        /// Protocol version handed to `decoder::decode_mqtt` / `encoder::encode_mqtt`.
        version: ProtocolVersion,
    }

    impl MqttCodec {
        /// Creates a codec that assumes MQTT v3.1.1 until a CONNECT packet says otherwise.
        pub fn new() -> Self {
            Self { version: ProtocolVersion::V311 }
        }

        /// Decodes the next packet from `buf` with `decoder::decode_mqtt`.
        ///
        /// If the decoded packet is a CONNECT, its protocol version becomes the codec's
        /// version for all subsequent calls.
        ///
        /// TODO - Ideally we should keep a state machine to store the data we've read so far.
        pub fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, DecodeError> {
            let decoded = decoder::decode_mqtt(buf, self.version)?;

            if let Some(Packet::Connect(connect)) = &decoded {
                self.version = connect.protocol_version;
            }

            Ok(decoded)
        }

        /// Serializes `packet` into `bytes` using the codec's current protocol version.
        ///
        /// NOTE(review): always returns `Ok(())`; whatever `encoder::encode_mqtt` returns
        /// is discarded — confirm that call is infallible.
        pub fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), EncodeError> {
            encoder::encode_mqtt(&packet, bytes, self.version);
            Ok(())
        }
    }

    impl Default for MqttCodec {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Decoder for MqttCodec {
        type Item = Packet;
        type Error = DecodeError;

        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            // Explicit path: delegate to the inherent method, not back into this trait method.
            MqttCodec::decode(self, buf)
        }
    }

    impl Encoder<Packet> for MqttCodec {
        type Error = EncodeError;

        fn encode(&mut self, packet: Packet, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            // Explicit path: delegate to the inherent method, not back into this trait method.
            MqttCodec::encode(self, packet, bytes)
        }
    }
}
77
#[cfg(feature = "websocket")]
pub mod websocket {
    //! Server side of the WebSocket opening handshake (RFC 6455 §4.2) for MQTT over
    //! WebSockets.

    use bytes::{Buf, BytesMut};
    use tokio_util::codec::{Decoder, Encoder};

    pub use websocket_codec as codec;

    /// GUID concatenated with the client's `Sec-WebSocket-Key` before hashing (RFC 6455 §1.3).
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// Blank line that ends the HTTP request head.
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    /// Maximum number of bytes buffered while waiting for the end of the request head.
    /// A peer that never finishes its head is rejected instead of growing the buffer
    /// without bound.
    const MAX_HANDSHAKE_LEN: usize = 16 * 1024;

    /// Errors produced while decoding a WebSocket upgrade request.
    #[derive(Debug)]
    pub enum WsDecodeError {
        /// The request head is not valid UTF-8.
        InvalidString,
        /// The request line is malformed, the method is not `GET`, or the request head
        /// exceeds the buffering limit.
        InvalidUpgradeRequest,
        /// The HTTP version is missing, unparsable, or older than 1.1.
        InvalidHttpVersion,
        /// A header line is malformed or a handshake header has an unacceptable value.
        InvalidUpgradeHeaders,
        /// The request carries no `Sec-WebSocket-Key` header.
        MissingWebSocketKey,
        /// Underlying I/O failure.
        Io(std::io::Error),
    }

    /// Errors produced while encoding the handshake response.
    #[derive(Debug)]
    pub enum WsEncodeError {
        /// Underlying I/O failure.
        Io(std::io::Error),
    }

    impl From<std::io::Error> for WsDecodeError {
        fn from(err: std::io::Error) -> WsDecodeError {
            WsDecodeError::Io(err)
        }
    }

    impl From<std::io::Error> for WsEncodeError {
        fn from(err: std::io::Error) -> WsEncodeError {
            WsEncodeError::Io(err)
        }
    }

    /// Codec for the handshake phase.
    ///
    /// Decoding turns a complete client upgrade request into its `Sec-WebSocket-Accept`
    /// value. Encoding writes the `101 Switching Protocols` response carrying that value.
    #[derive(Debug, Default)]
    pub struct WsUpgraderCodec {}

    impl WsUpgraderCodec {
        pub fn new() -> Self {
            Self {}
        }

        /// Checks that the request line is `GET <uri> HTTP/<major>.<minor>` with a version
        /// of at least 1.1 (RFC 6455 §4.1).
        fn validate_request_line(request_line: &str) -> Result<(), WsDecodeError> {
            let mut request_parts = request_line.split_whitespace();

            let (method, version) =
                match (request_parts.next(), request_parts.next(), request_parts.next()) {
                    (Some(method), Some(_uri), Some(version)) => (method, version),
                    _ => return Err(WsDecodeError::InvalidUpgradeRequest),
                };

            if !method.eq_ignore_ascii_case("get") {
                return Err(WsDecodeError::InvalidUpgradeRequest);
            }

            let mut version_parts = version.splitn(2, '/');
            let protocol = version_parts.next().ok_or(WsDecodeError::InvalidHttpVersion)?;
            let http_version = version_parts.next().ok_or(WsDecodeError::InvalidHttpVersion)?;
            if !protocol.eq_ignore_ascii_case("HTTP") {
                return Err(WsDecodeError::InvalidHttpVersion);
            }

            let mut numbers = http_version.split('.');
            let major: u8 = numbers
                .next()
                .and_then(|n| n.parse().ok())
                .ok_or(WsDecodeError::InvalidHttpVersion)?;
            let minor: u8 = numbers
                .next()
                .and_then(|n| n.parse().ok())
                .ok_or(WsDecodeError::InvalidHttpVersion)?;

            if (major, minor) >= (1, 1) {
                Ok(())
            } else {
                Err(WsDecodeError::InvalidHttpVersion)
            }
        }

        /// Returns `true` if the comma-separated header `value` contains `token`.
        /// Surrounding whitespace is ignored and the comparison is ASCII case-insensitive.
        fn contains_token(value: &str, token: &str) -> bool {
            value.split(',').any(|item| item.trim().eq_ignore_ascii_case(token))
        }

        /// Validates the handshake headers and returns the client's `Sec-WebSocket-Key`.
        ///
        /// Iteration stops at the first empty line (end of the header section). Headers
        /// not relevant to the handshake are ignored.
        fn validate_headers<'a>(
            header_lines: impl Iterator<Item = &'a str>,
        ) -> Result<&'a str, WsDecodeError> {
            let mut websocket_key: Option<&'a str> = None;

            for header_line in header_lines.take_while(|line| !line.is_empty()) {
                // Split on the first ':' only; values such as `Host: example.com:8080`
                // contain colons themselves.
                let mut split_line = header_line.splitn(2, ':');
                let header_name =
                    split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();
                let header_val =
                    split_line.next().ok_or(WsDecodeError::InvalidUpgradeHeaders)?.trim();

                match header_name {
                    // `Upgrade` and `Connection` are token lists compared case-insensitively
                    // (RFC 6455 §4.2.1); clients may send e.g. `Connection: keep-alive, Upgrade`.
                    name if name.eq_ignore_ascii_case("Upgrade") => {
                        if !Self::contains_token(header_val, "websocket") {
                            return Err(WsDecodeError::InvalidUpgradeHeaders);
                        }
                    },
                    name if name.eq_ignore_ascii_case("Connection") => {
                        if !Self::contains_token(header_val, "Upgrade") {
                            return Err(WsDecodeError::InvalidUpgradeHeaders);
                        }
                    },
                    name if name.eq_ignore_ascii_case("Sec-WebSocket-Key") => {
                        websocket_key = Some(header_val);
                    },
                    name if name.eq_ignore_ascii_case("Sec-WebSocket-Version") => {
                        if header_val != "13" {
                            return Err(WsDecodeError::InvalidUpgradeHeaders);
                        }
                    },
                    name if name.eq_ignore_ascii_case("Sec-WebSocket-Protocol") => {
                        // Comma-separated list whose entries may be padded with whitespace,
                        // e.g. `mqttv3.1, mqtt`.
                        if !header_val.split(',').any(|proto| proto.trim() == "mqtt") {
                            return Err(WsDecodeError::InvalidUpgradeHeaders);
                        }
                    },
                    _ => {},
                }
            }

            websocket_key.ok_or(WsDecodeError::MissingWebSocketKey)
        }

        /// Computes `Sec-WebSocket-Accept` = base64(SHA-1(key ++ GUID)) (RFC 6455 §4.2.2).
        fn accept_key(websocket_key: &str) -> String {
            let mut hasher = sha1::Sha1::new();
            hasher.update(websocket_key.as_bytes());
            hasher.update(WEBSOCKET_GUID);
            base64::encode(&hasher.digest().bytes())
        }
    }

    impl Decoder for WsUpgraderCodec {
        type Error = WsDecodeError;
        type Item = String;

        /// Decodes one upgrade request and yields its `Sec-WebSocket-Accept` value.
        ///
        /// Returns `Ok(None)` until the full request head (ending in a blank line) is
        /// buffered. Only the head is consumed; any bytes after it stay in `buf`.
        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            let head_len = match buf
                .windows(HEADER_TERMINATOR.len())
                .position(|window| window == HEADER_TERMINATOR)
            {
                Some(pos) => pos,
                None if buf.len() > MAX_HANDSHAKE_LEN => {
                    return Err(WsDecodeError::InvalidUpgradeRequest);
                },
                None => return Ok(None),
            };

            let head =
                std::str::from_utf8(&buf[..head_len]).map_err(|_| WsDecodeError::InvalidString)?;
            let mut lines = head.split("\r\n");

            let request_line = lines.next().ok_or(WsDecodeError::InvalidUpgradeRequest)?;
            Self::validate_request_line(request_line)?;

            let websocket_key = Self::validate_headers(lines)?;
            let accept = Self::accept_key(websocket_key);

            // Drop the request head and its terminator only; later bytes belong to the
            // WebSocket stream.
            buf.advance(head_len + HEADER_TERMINATOR.len());

            Ok(Some(accept))
        }
    }

    impl Encoder<String> for WsUpgraderCodec {
        type Error = WsEncodeError;

        /// Writes the `101 Switching Protocols` response into `bytes`. `accept_key` is the
        /// `Sec-WebSocket-Accept` value yielded by the decoder.
        ///
        /// NOTE(review): the response always advertises the `mqtt` subprotocol, even when
        /// the client sent no `Sec-WebSocket-Protocol` header (the decoder accepts that);
        /// confirm clients tolerate this.
        fn encode(&mut self, accept_key: String, bytes: &mut BytesMut) -> Result<(), Self::Error> {
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
                Upgrade: websocket\r\n\
                Connection: Upgrade\r\n\
                Sec-WebSocket-Protocol: mqtt\r\n\
                Sec-WebSocket-Accept: {}\r\n\r\n",
                accept_key
            );

            bytes.extend_from_slice(response.as_bytes());
            Ok(())
        }
    }
}