// lightws/handshake/request.rs

//! Client upgrade request.
//!
//! From [RFC-6455 Section 4.1](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1):
//!
//! Once a connection to the server has been established (including a
//! connection via a proxy or over a TLS-encrypted tunnel), the client
//! MUST send an opening handshake to the server.  The handshake consists
//! of an HTTP Upgrade request, along with a list of required and
//! optional header fields.
//!
//! Once the client's opening handshake has been sent, the client MUST
//! wait for a response from the server before sending any further data.
//!
//! Example:
//!
//! ```text
//! GET /path HTTP/1.1
//! host: www.example.com
//! upgrade: websocket
//! connection: upgrade
//! sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==
//! sec-websocket-version: 13
//! ```
//!

use core::mem;

use super::{HttpHeader, HeaderHelper};
use super::{write_header, filter_header};
use super::handshake_check;
use super::MAX_ALLOW_HEADERS;
use super::{HTTP_METHOD, HTTP_VERSION, HTTP_LINE_BREAK, HTTP_HEADER_SP};
use super::static_headers::*;

use crate::bleed::Writer;
use crate::error::HandshakeError;

36/// Http request presentation.
37pub struct Request<'h, 'b: 'h, const N: usize = MAX_ALLOW_HEADERS> {
38    pub path: &'b [u8],
39    pub host: &'b [u8],
40    pub sec_key: &'b [u8],
41    pub other_headers: &'h mut [HttpHeader<'b>],
42}
43
44impl<'h, 'b: 'h, const N: usize> HeaderHelper for Request<'h, 'b, N> {
45    const SIZE: usize = N;
46}
47
48impl<'h, 'b: 'h> Request<'h, 'b> {
49    /// Create a new request without extra headers.
50    /// This is usually used to send a request.
51    #[inline]
52    pub const fn new(path: &'b [u8], host: &'b [u8], sec_key: &'b [u8]) -> Self {
53        Self {
54            path,
55            host,
56            sec_key,
57            other_headers: &mut [],
58        }
59    }
60
61    /// Create a new request with extra headers.
62    /// This is usually used to send a request.
63    #[inline]
64    pub const fn new_with_headers(
65        path: &'b [u8],
66        host: &'b [u8],
67        sec_key: &'b [u8],
68        other_headers: &'h mut [HttpHeader<'b>],
69    ) -> Self {
70        Self {
71            path,
72            host,
73            sec_key,
74            other_headers,
75        }
76    }
77
78    /// Create with user provided headers storage, other fields are left empty.
79    /// This is usually used to receive a request.
80    ///
81    /// The max decode header size is [`MAX_ALLOW_HEADERS`].
82    #[inline]
83    pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
84        Self {
85            path: &[],
86            host: &[],
87            sec_key: &[],
88            other_headers,
89        }
90    }
91}
92
93impl<'h, 'b: 'h, const N: usize> Request<'h, 'b, N> {
94    /// Create with user provided headers storage, other fields are left empty.
95    /// This is usually used to receive a request.
96    ///
97    /// The const generic paramater represents the max decode header size.
98    #[inline]
99    pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
100        Self {
101            path: &[],
102            host: &[],
103            sec_key: &[],
104            other_headers,
105        }
106    }
107
108    /// Encode to a provided buffer, return the number of written bytes.
109    ///
110    /// Necessary headers, including `host`, `upgrade`, `connection`,
111    /// `sec-websocket-key` and `sec-websocket-version` are written to
112    /// the buffer, then other headers(if any) are written in order.
113    ///
114    /// Caller should make sure there is enough space to write,
115    /// otherwise a [`HandshakeError::NotEnoughCapacity`] error will be returned.
116    pub fn encode(&self, buf: &mut [u8]) -> Result<usize, HandshakeError> {
117        debug_assert!(buf.len() > 80);
118
119        let mut w = Writer::new(buf);
120
121        // GET {path} HTTP/1.1
122        unsafe {
123            w.write_unchecked(HTTP_METHOD);
124            w.write_byte_unchecked(0x20);
125            w.write_unchecked(self.path);
126            w.write_byte_unchecked(0x20);
127            w.write_unchecked(HTTP_VERSION);
128            w.write_unchecked(HTTP_LINE_BREAK);
129        }
130
131        // host: {host}
132        write_header!(w, HEADER_HOST_NAME, self.host);
133
134        // upgrade: websocket
135        write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE);
136
137        // connection: upgrade
138        write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE);
139
140        // sec-websocket-key: {sec_key}
141        write_header!(w, HEADER_SEC_WEBSOCKET_KEY_NAME, self.sec_key);
142
143        // sec-websocket-version: 13
144        write_header!(
145            w,
146            HEADER_SEC_WEBSOCKET_VERSION_NAME,
147            HEADER_SEC_WEBSOCKET_VERSION_VALUE
148        );
149
150        // other headers
151        for hdr in self.other_headers.iter() {
152            write_header!(w, hdr)
153        }
154
155        // finish with CRLF
156        w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?;
157
158        Ok(w.pos())
159    }
160
161    /// Parse from a provided buffer, save the results, and
162    /// return the number of bytes parsed.
163    ///
164    /// Necessary headers, including `host`, `upgrade`, `connection`,
165    /// `sec-websocket-key` and `sec-websocket-version` are parsed and checked,
166    /// and stored in the struct. Optional headers
167    /// (like `sec-websocket-protocol`) are stored in `other_headers`.
168    /// After the parse, `other_headers` will be shrunk to
169    /// fit the number of stored headers.
170    ///
171    /// Caller should make sure there is enough space
172    /// (default is [`MAX_ALLOW_HEADERS`]) to store headers,
173    /// which could be specified by the const generic paramater.
174    /// If the buffer does not contain a complete http request,
175    /// a [`HandshakeError::NotEnoughData`] error will be returned.
176    /// If the required headers(mentioned above) do not pass the check
177    /// (case insensitive), other corresponding errors will be returned.
178    pub fn decode(&mut self, buf: &'b [u8]) -> Result<usize, HandshakeError> {
179        debug_assert!(self.other_headers.len() >= <Self as HeaderHelper>::SIZE);
180
181        let mut headers = [httparse::EMPTY_HEADER; N];
182        let mut request = httparse::Request::new(&mut headers);
183
184        // return value
185        let decode_n = match request.parse(buf)? {
186            httparse::Status::Complete(n) => n,
187            httparse::Status::Partial => return Err(HandshakeError::NotEnoughData),
188        };
189
190        // check method
191        if request.method.unwrap().as_bytes() != HTTP_METHOD {
192            return Err(HandshakeError::HttpMethod);
193        }
194
195        // check version, should be HTTP/1.1
196        // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#581-596
197        if request.version.unwrap() != 1_u8 {
198            return Err(HandshakeError::HttpVersion);
199        }
200
201        // handle headers below
202        // headers are shrunk to number of inited headers
203        // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#757-765
204        let headers = request.headers;
205
206        let mut required_headers = [
207            HEADER_HOST,
208            HEADER_UPGRADE,
209            HEADER_CONNECTION,
210            HEADER_SEC_WEBSOCKET_KEY,
211            HEADER_SEC_WEBSOCKET_VERSION,
212        ];
213
214        // filter required headers, save other headers
215        filter_header(headers, &mut required_headers, self.other_headers);
216
217        let [host_hdr, upgrade_hdr, connection_hdr, sec_key_hdr, sec_version_hdr] =
218            required_headers;
219
220        // check missing header
221        if !required_headers.iter().all(|h| !h.value.is_empty()) {
222            handshake_check!(host_hdr, HandshakeError::HttpHost);
223            handshake_check!(upgrade_hdr, HandshakeError::Upgrade);
224            handshake_check!(connection_hdr, HandshakeError::Connection);
225            handshake_check!(sec_key_hdr, HandshakeError::SecWebSocketKey);
226            handshake_check!(sec_version_hdr, HandshakeError::SecWebSocketVersion);
227        }
228
229        // check header value (case insensitive)
230        // ref: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
231        handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade);
232
233        handshake_check!(
234            connection_hdr,
235            HEADER_CONNECTION_VALUE,
236            HandshakeError::Connection
237        );
238
239        handshake_check!(
240            sec_version_hdr,
241            HEADER_SEC_WEBSOCKET_VERSION_VALUE,
242            HandshakeError::SecWebSocketVersion
243        );
244
245        // save ref
246        self.path = request.path.unwrap().as_bytes();
247        self.host = host_hdr.value;
248        self.sec_key = sec_key_hdr.value;
249
250        // shrink header reference
251        let other_header_len = headers.len() - required_headers.len();
252
253        // remove lifetime here, remember that
254        // &mut other_headers lives longer than &mut self
255        let other_headers: &'h mut [HttpHeader<'b>] =
256            unsafe { &mut *(self.other_headers as *mut _) };
257        self.other_headers = unsafe { other_headers.get_unchecked_mut(0..other_header_len) };
258
259        Ok(decode_n)
260    }
261}
262
#[cfg(test)]
mod test {
    use super::*;
    use super::super::HttpHeader;
    use super::super::test::{make_headers, TEMPLATE_HEADERS};
    use rand::prelude::*;

    #[test]
    fn client_handshake() {
        for i in 0..64 {
            let hdr_len: usize = thread_rng().gen_range(1..128);
            let headers = format!(
                "GET / HTTP/1.1\r\n{}\r\n",
                make_headers(i, hdr_len, TEMPLATE_HEADERS)
            );

            let mut other_headers = HttpHeader::new_custom_storage::<1024>();
            let mut request = Request::<1024>::new_custom_storage(&mut other_headers);
            let decode_n = request.decode(headers.as_bytes()).unwrap();

            assert_eq!(decode_n, headers.len());
            assert_eq!(request.path, b"/");
            assert_eq!(request.host, b"www.example.com");
            assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ==");

            // other headers
            macro_rules! match_other {
                ($name: expr, $value: expr) => {{
                    request
                        .other_headers
                        .iter()
                        .find(|hdr| hdr.name == $name && hdr.value == $value)
                        .unwrap();
                }};
            }
            match_other!(b"sec-websocket-accept", b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");

            let mut buf: Vec<u8> = vec![0; 0x4000];
            let encode_n = request.encode(&mut buf).unwrap();
            assert_eq!(encode_n, decode_n);
        }
    }

    #[test]
    fn client_handshake2() {
        macro_rules! run {
            ($host: expr, $path: expr, $sec_key: expr) => {{
                let headers = format!(
                    "GET {1} HTTP/1.1\r\n{0}\r\n",
                    make_headers(
                        16,
                        32,
                        &format!(
                            "host: {0}\r\n\
                            sec-websocket-key: {1}\r\n\
                            upgrade: websocket\r\n\
                            connection: upgrade\r\n\
                            sec-websocket-version: 13",
                            $host, $sec_key
                        )
                    ),
                    $path
                );

                let mut other_headers = HttpHeader::new_storage();
                let mut request = Request::new_storage(&mut other_headers);
                let decode_n = request.decode(headers.as_bytes()).unwrap();
                assert_eq!(decode_n, headers.len());
                assert_eq!(request.host, $host.as_bytes());
                assert_eq!(request.path, $path.as_bytes());
                assert_eq!(request.sec_key, $sec_key.as_bytes());

                let mut buf: Vec<u8> = vec![0; 0x4000];
                let encode_n = request.encode(&mut buf).unwrap();
                assert_eq!(encode_n, decode_n);
            }};
        }

        run!("host", "/path", "key");
        run!("www.abc.com", "/path/to", "xxxxxx");
        run!("wwww.www.ww.w", "/path/to/to/path", "xxxxxxyyyy");
    }

    #[test]
    fn client_handshake_errors() {
        const VALID: &str = "GET / HTTP/1.1\r\n\
                             host: www.example.com\r\n\
                             upgrade: websocket\r\n\
                             connection: upgrade\r\n\
                             sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
                             sec-websocket-version: 13\r\n\r\n";

        // Decode `req` into fresh default storage, keeping only the result.
        fn decode(req: &str) -> Result<usize, HandshakeError> {
            let mut other_headers = HttpHeader::new_storage();
            let mut request = Request::new_storage(&mut other_headers);
            request.decode(req.as_bytes())
        }

        // baseline: the unmodified request is accepted
        assert!(matches!(decode(VALID), Ok(n) if n == VALID.len()));

        // missing final CRLF
        assert!(matches!(
            decode(&VALID[..VALID.len() - 2]),
            Err(HandshakeError::NotEnoughData)
        ));

        // request line
        assert!(matches!(
            decode(&VALID.replacen("GET", "POST", 1)),
            Err(HandshakeError::HttpMethod)
        ));
        assert!(matches!(
            decode(&VALID.replacen("HTTP/1.1", "HTTP/1.0", 1)),
            Err(HandshakeError::HttpVersion)
        ));

        // required headers
        assert!(matches!(
            decode(&VALID.replacen("host: www.example.com\r\n", "", 1)),
            Err(HandshakeError::HttpHost)
        ));
        assert!(matches!(
            decode(&VALID.replacen("upgrade: websocket", "upgrade: h2c", 1)),
            Err(HandshakeError::Upgrade)
        ));
        assert!(matches!(
            decode(&VALID.replacen("sec-websocket-version: 13", "sec-websocket-version: 8", 1)),
            Err(HandshakeError::SecWebSocketVersion)
        ));
    }
}
345
346    // catch errors ...
347}