lightws/handshake/
mod.rs

1//! Websocket handshake.
2
3pub mod key;
4pub mod request;
5pub mod response;
6
7pub use request::Request;
8pub use response::Response;
9pub use key::{new_sec_key, derive_accept_key};
10
11/// 32
12pub const MAX_ALLOW_HEADERS: usize = 32;
13
14/// Empty header with dummy reference
15pub const EMPTY_HEADER: HttpHeader = HttpHeader::new(b"", b"");
16
17/// 258EAFA5-E914-47DA-95CA-C5AB0DC85B11
18pub const GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
19
20/// GET
21pub const HTTP_METHOD: &[u8] = b"GET";
22
23/// HTTP/1.1
24pub const HTTP_VERSION: &[u8] = b"HTTP/1.1";
25
26/// CRLF
27pub const HTTP_LINE_BREAK: &[u8] = b"\r\n";
28
29/// A colon + one SP is prefered
30pub const HTTP_HEADER_SP: &[u8] = b": ";
31
32/// HTTP/1.1 101 Switching Protocols
33pub const HTTP_STATUS_LINE: &[u8] = b"HTTP/1.1 101 Switching Protocols";
34
/// A borrowed HTTP header made of two byte-slice references: a `name` and a
/// `value`, both living at least as long as `'h`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
// `len` reports the encoded size of the header line (it always includes the
// separator and CRLF), so an `is_empty` companion would not be meaningful.
#[allow(clippy::len_without_is_empty)]
pub struct HttpHeader<'h> {
    /// Raw header name bytes.
    pub name: &'h [u8],
    /// Raw header value bytes; may be empty.
    pub value: &'h [u8],
}
42
/// Compile-time helper: an implementor exposes a size known at compile time
/// through the associated constant `SIZE`.
// NOTE(review): no implementor is visible in this file; presumably used by
// the request/response submodules — confirm before removing.
trait HeaderHelper {
    /// Size value computed at compile time by the implementor.
    const SIZE: usize;
}
47
48impl<'h> HttpHeader<'h> {
49    /// Constructor, take provided name and value.
50    #[inline]
51    pub const fn new(name: &'h [u8], value: &'h [u8]) -> Self { Self { name, value } }
52
53    /// Total number of bytes(name + value + sp).
54    #[inline]
55    pub const fn len(&self) -> usize {
56        self.name.len() + self.value.len() + HTTP_HEADER_SP.len() + HTTP_LINE_BREAK.len()
57    }
58
59    /// Create [`MAX_ALLOW_HEADERS`] empty headers.
60    #[inline]
61    pub const fn new_storage() -> [HttpHeader<'static>; MAX_ALLOW_HEADERS] {
62        [EMPTY_HEADER; MAX_ALLOW_HEADERS]
63    }
64
65    /// Create N empty headers.
66    #[inline]
67    pub const fn new_custom_storage<const N: usize>() -> [HttpHeader<'static>; N] {
68        [EMPTY_HEADER; N]
69    }
70}
71
72impl Default for HttpHeader<'static> {
73    fn default() -> Self { EMPTY_HEADER }
74}
75
76impl<'h> std::fmt::Display for HttpHeader<'h> {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        use std::str::from_utf8_unchecked;
79        write!(
80            f,
81            "{}: {}",
82            unsafe { from_utf8_unchecked(self.name) },
83            unsafe { from_utf8_unchecked(self.value) }
84        )
85    }
86}
87
/// Declare header constants.
///
/// Two forms are accepted. Each is a list of `;`-terminated entries, and any
/// entry may be preceded by attributes such as `///` doc comments:
///
/// * `(IDENT => name, value);` declares `pub const IDENT: HttpHeader`
///   built with `HttpHeader::new(name, value)`. `HttpHeader` is resolved at
///   the call site, so it must be in scope there.
/// * `(IDENT => bytes);` declares `pub const IDENT: &[u8] = bytes`.
macro_rules! header {
    (   $(
            $(#[$docs: meta])*
            ($hdr: ident => $name: expr, $value: expr);
        )+
    ) => {
        $(
            $(#[$docs])*
            pub const $hdr: HttpHeader = HttpHeader::new($name, $value);
        )+
    };
    (   $(
            $(#[$docs: meta])*
            ($hdr_name: ident => $name: expr);
        )+
    ) => {
        $(
            $(#[$docs])*
            pub const $hdr_name: &[u8] = $name;
        )+
    };
}
108
/// Write one header line (`name: value\r\n`) into the writer `$w`.
///
/// * `write_header!(w, hdr)` writes an existing header value;
/// * `write_header!(w, name, value)` builds one via `HttpHeader::new`.
///
/// If `$w.remaining()` is smaller than the encoded header, nothing is written
/// and `Err(HandshakeError::NotEnoughCapacity)` is returned from the
/// *enclosing* function. `$w` must provide `remaining()` and an unsafe
/// `write_unchecked(&[u8])`. `HttpHeader`, `HandshakeError`, `HTTP_HEADER_SP`
/// and `HTTP_LINE_BREAK` are resolved at the call site.
macro_rules! write_header {
    ($w: expr, $hdr: expr) => {{
        // Evaluate the header expression once; it is read several times below.
        let hdr = $hdr;
        if $w.remaining() < hdr.len() {
            return Err(HandshakeError::NotEnoughCapacity);
        }
        // SAFETY: `hdr.len()` is exactly the total length of the four slices
        // written below (name, ": ", value, "\r\n"), and the check above
        // guarantees the writer has at least that much room left.
        // NOTE(review): assumes sufficient remaining capacity is the only
        // precondition of `write_unchecked` — confirm against the writer type.
        unsafe {
            $w.write_unchecked(hdr.name);
            $w.write_unchecked(HTTP_HEADER_SP);
            $w.write_unchecked(hdr.value);
            $w.write_unchecked(HTTP_LINE_BREAK);
        }
    }};
    ($w: expr, $name: expr, $value: expr) => {
        // No trailing `;`, so this arm is valid in expression position too.
        write_header!($w, HttpHeader::new($name, $value))
    };
}
126
/// Validate a parsed header, returning `Err($e)` from the enclosing function
/// on failure.
///
/// * `handshake_check!(hdr, err)` requires `hdr.value` to be non-empty.
/// * `handshake_check!(hdr, expected, err)` requires `hdr.value` to contain
///   `expected` as one of its comma-separated tokens. Tokens are compared
///   ASCII case-insensitively, with surrounding whitespace ignored; an empty
///   value never matches.
///
/// Token matching is needed because RFC 6455 section 4.2.1 only requires the
/// `Connection` header to *include* the `Upgrade` token, and some clients
/// (e.g. Firefox) send `Connection: keep-alive, Upgrade`. For a value with no
/// commas this is the same case-insensitive whole-value comparison as before,
/// apart from ignoring surrounding whitespace.
macro_rules! handshake_check {
    ($hdr: expr, $e: expr) => {
        if $hdr.value.is_empty() {
            return Err($e);
        }
    };
    ($hdr: expr, $value: expr, $e: expr) => {{
        let expected: &[u8] = $value;
        let found = $hdr.value.split(|&b| b == b',').any(|token: &[u8]| {
            // Trim optional whitespace around the token.
            let start = token
                .iter()
                .position(|b| !b.is_ascii_whitespace())
                .unwrap_or(token.len());
            let end = token
                .iter()
                .rposition(|b| !b.is_ascii_whitespace())
                .map_or(start, |i| i + 1);
            let token = &token[start..end];
            !token.is_empty() && token.eq_ignore_ascii_case(expected)
        });
        if !found {
            return Err($e);
        }
    }};
}
141
142use write_header;
143use handshake_check;
144
145#[inline]
146fn filter_header<'h>(
147    all: &[httparse::Header<'h>],
148    required: &mut [HttpHeader<'h>],
149    other: &mut [HttpHeader<'h>],
150) {
151    let mut other_iter = other.iter_mut();
152    for hdr in all.iter() {
153        let name = hdr.name.as_bytes();
154
155        if let Some(h) = required
156            .iter_mut()
157            .filter(|h| h.value.is_empty())
158            .find(|h| h.name.eq_ignore_ascii_case(name))
159        {
160            h.value = hdr.value;
161        } else {
162            let other_hdr = other_iter.next().unwrap();
163            other_hdr.name = name;
164            other_hdr.value = hdr.value;
165        }
166    }
167}
168
169/// Static http headers
170#[allow(unused)]
171pub mod static_headers {
172    use super::HttpHeader;
173    // header
174    header!(
175        /// host: {host}
176        (HEADER_HOST => b"host", b"");
177
178        /// upgrade: websocket
179        (HEADER_UPGRADE => b"upgrade", b"");
180
181        /// connection: upgrade
182        (HEADER_CONNECTION => b"connection", b"");
183
184        /// sec-websocket-key: {key}
185        (HEADER_SEC_WEBSOCKET_KEY => b"sec-websocket-key", b"");
186
187        /// sec-websocket-accept: {accept}
188        (HEADER_SEC_WEBSOCKET_ACCEPT => b"sec-websocket-accept", b"");
189
190        /// sec-webSocket-version: 13
191        (HEADER_SEC_WEBSOCKET_VERSION => b"sec-webSocket-version", b"");
192    );
193
194    // header name
195    header! {
196        (HEADER_HOST_NAME => b"host");
197
198        (HEADER_UPGRADE_NAME => b"upgrade");
199
200        (HEADER_CONNECTION_NAME => b"connection");
201
202        (HEADER_SEC_WEBSOCKET_KEY_NAME => b"sec-websocket-key");
203
204        (HEADER_SEC_WEBSOCKET_ACCEPT_NAME => b"sec-websocket-accept");
205
206        (HEADER_SEC_WEBSOCKET_VERSION_NAME => b"sec-websocket-version");
207    }
208
209    // header value
210    header! {
211        (HEADER_UPGRADE_VALUE => b"websocket");
212
213        (HEADER_CONNECTION_VALUE => b"upgrade");
214
215        (HEADER_SEC_WEBSOCKET_VERSION_VALUE => b"13");
216    }
217}
218
#[cfg(test)]
mod test {
    use rand::prelude::*;

    /// One complete set of handshake headers, CRLF-separated, with no
    /// trailing CRLF.
    pub const TEMPLATE_HEADERS: &str = "\
        host: www.example.com\r\n\
        upgrade: websocket\r\n\
        connection: upgrade\r\n\
        sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
        sec-websocket-version: 13";

    /// Build a shuffled header block.
    ///
    /// The block contains every line of `headers` plus `count` random
    /// `name: value` headers. Each random name and value has a length drawn
    /// from `1..=max_len`. Every line, including the last, ends with CRLF.
    ///
    /// Empty segments of `headers` (from an empty string, a trailing CRLF or
    /// a doubled CRLF) are skipped. An empty line ends an HTTP header block,
    /// so shuffling one into the middle would silently cut the block short.
    ///
    /// # Panics
    ///
    /// Panics if `count > 0` and `max_len == 0`, since `1..=0` is an empty
    /// range.
    pub fn make_headers(count: usize, max_len: usize, headers: &str) -> String {
        /// One random character from `[0-9A-Za-z-]`.
        fn rand_ascii() -> char {
            let mut rng = thread_rng();
            let kind: u8 = rng.gen_range(1..=4);
            let ch: u8 = match kind {
                1 => rng.gen_range(b'0'..=b'9'),
                2 => rng.gen_range(b'A'..=b'Z'),
                3 => rng.gen_range(b'a'..=b'z'),
                4 => b'-',
                _ => unreachable!(),
            };
            char::from(ch)
        }

        /// A random string of exactly `len` characters.
        fn rand_str(len: usize) -> String {
            (0..len).map(|_| rand_ascii()).collect()
        }

        /// One random `name: value\r\n` line.
        fn make_header(max_len: usize) -> String {
            let name_len: usize = thread_rng().gen_range(1..=max_len);
            let value_len: usize = thread_rng().gen_range(1..=max_len);
            format!("{}: {}\r\n", rand_str(name_len), rand_str(value_len))
        }

        let mut lines: Vec<String> = headers
            .split("\r\n")
            .filter(|hdr| !hdr.is_empty())
            .map(|hdr| format!("{}\r\n", hdr))
            .chain((0..count).map(|_| make_header(max_len)))
            .collect();
        lines.shuffle(&mut thread_rng());
        lines.concat()
    }
}