1use super::{HttpHeader, HeaderHelper};
27use super::{write_header, filter_header};
28use super::handshake_check;
29use super::MAX_ALLOW_HEADERS;
30use super::{HTTP_METHOD, HTTP_VERSION, HTTP_LINE_BREAK, HTTP_HEADER_SP};
31use super::static_headers::*;
32
33use crate::bleed::Writer;
34use crate::error::HandshakeError;
35
/// WebSocket opening-handshake request: the required request fields plus any
/// additional headers.
///
/// All byte slices borrow from `'b` (e.g. the buffer passed to `decode`);
/// `other_headers` is caller-provided header storage borrowed for `'h`.
/// `N` is the number of header slots used while parsing in `decode`; it
/// defaults to `MAX_ALLOW_HEADERS`.
pub struct Request<'h, 'b: 'h, const N: usize = MAX_ALLOW_HEADERS> {
    /// Request target written between the method and the HTTP version.
    pub path: &'b [u8],
    /// Value of the `Host` header.
    pub host: &'b [u8],
    /// Value of the `Sec-WebSocket-Key` header.
    pub sec_key: &'b [u8],
    /// Headers other than the required handshake headers. `encode` writes
    /// every entry; `decode` fills this storage and then narrows the slice to
    /// the headers it actually found.
    pub other_headers: &'h mut [HttpHeader<'b>],
}
43
// Exposes the const parameter `N` as `HeaderHelper::SIZE`; `decode` uses it
// as the minimum length it expects `other_headers` to have.
impl<'h, 'b: 'h, const N: usize> HeaderHelper for Request<'h, 'b, N> {
    const SIZE: usize = N;
}
47
48impl<'h, 'b: 'h> Request<'h, 'b> {
49 #[inline]
52 pub const fn new(path: &'b [u8], host: &'b [u8], sec_key: &'b [u8]) -> Self {
53 Self {
54 path,
55 host,
56 sec_key,
57 other_headers: &mut [],
58 }
59 }
60
61 #[inline]
64 pub const fn new_with_headers(
65 path: &'b [u8],
66 host: &'b [u8],
67 sec_key: &'b [u8],
68 other_headers: &'h mut [HttpHeader<'b>],
69 ) -> Self {
70 Self {
71 path,
72 host,
73 sec_key,
74 other_headers,
75 }
76 }
77
78 #[inline]
83 pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
84 Self {
85 path: &[],
86 host: &[],
87 sec_key: &[],
88 other_headers,
89 }
90 }
91}
92
93impl<'h, 'b: 'h, const N: usize> Request<'h, 'b, N> {
94 #[inline]
99 pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self {
100 Self {
101 path: &[],
102 host: &[],
103 sec_key: &[],
104 other_headers,
105 }
106 }
107
108 pub fn encode(&self, buf: &mut [u8]) -> Result<usize, HandshakeError> {
117 debug_assert!(buf.len() > 80);
118
119 let mut w = Writer::new(buf);
120
121 unsafe {
123 w.write_unchecked(HTTP_METHOD);
124 w.write_byte_unchecked(0x20);
125 w.write_unchecked(self.path);
126 w.write_byte_unchecked(0x20);
127 w.write_unchecked(HTTP_VERSION);
128 w.write_unchecked(HTTP_LINE_BREAK);
129 }
130
131 write_header!(w, HEADER_HOST_NAME, self.host);
133
134 write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE);
136
137 write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE);
139
140 write_header!(w, HEADER_SEC_WEBSOCKET_KEY_NAME, self.sec_key);
142
143 write_header!(
145 w,
146 HEADER_SEC_WEBSOCKET_VERSION_NAME,
147 HEADER_SEC_WEBSOCKET_VERSION_VALUE
148 );
149
150 for hdr in self.other_headers.iter() {
152 write_header!(w, hdr)
153 }
154
155 w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?;
157
158 Ok(w.pos())
159 }
160
161 pub fn decode(&mut self, buf: &'b [u8]) -> Result<usize, HandshakeError> {
179 debug_assert!(self.other_headers.len() >= <Self as HeaderHelper>::SIZE);
180
181 let mut headers = [httparse::EMPTY_HEADER; N];
182 let mut request = httparse::Request::new(&mut headers);
183
184 let decode_n = match request.parse(buf)? {
186 httparse::Status::Complete(n) => n,
187 httparse::Status::Partial => return Err(HandshakeError::NotEnoughData),
188 };
189
190 if request.method.unwrap().as_bytes() != HTTP_METHOD {
192 return Err(HandshakeError::HttpMethod);
193 }
194
195 if request.version.unwrap() != 1_u8 {
198 return Err(HandshakeError::HttpVersion);
199 }
200
201 let headers = request.headers;
205
206 let mut required_headers = [
207 HEADER_HOST,
208 HEADER_UPGRADE,
209 HEADER_CONNECTION,
210 HEADER_SEC_WEBSOCKET_KEY,
211 HEADER_SEC_WEBSOCKET_VERSION,
212 ];
213
214 filter_header(headers, &mut required_headers, self.other_headers);
216
217 let [host_hdr, upgrade_hdr, connection_hdr, sec_key_hdr, sec_version_hdr] =
218 required_headers;
219
220 if !required_headers.iter().all(|h| !h.value.is_empty()) {
222 handshake_check!(host_hdr, HandshakeError::HttpHost);
223 handshake_check!(upgrade_hdr, HandshakeError::Upgrade);
224 handshake_check!(connection_hdr, HandshakeError::Connection);
225 handshake_check!(sec_key_hdr, HandshakeError::SecWebSocketKey);
226 handshake_check!(sec_version_hdr, HandshakeError::SecWebSocketVersion);
227 }
228
229 handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade);
232
233 handshake_check!(
234 connection_hdr,
235 HEADER_CONNECTION_VALUE,
236 HandshakeError::Connection
237 );
238
239 handshake_check!(
240 sec_version_hdr,
241 HEADER_SEC_WEBSOCKET_VERSION_VALUE,
242 HandshakeError::SecWebSocketVersion
243 );
244
245 self.path = request.path.unwrap().as_bytes();
247 self.host = host_hdr.value;
248 self.sec_key = sec_key_hdr.value;
249
250 let other_header_len = headers.len() - required_headers.len();
252
253 let other_headers: &'h mut [HttpHeader<'b>] =
256 unsafe { &mut *(self.other_headers as *mut _) };
257 self.other_headers = unsafe { other_headers.get_unchecked_mut(0..other_header_len) };
258
259 Ok(decode_n)
260 }
261}
262
#[cfg(test)]
mod test {
    use super::*;
    use super::super::HttpHeader;
    use super::super::test::{make_headers, TEMPLATE_HEADERS};
    use rand::prelude::*;

    /// Round-trips randomly generated handshake requests through `decode`
    /// and `encode`, using 1024 header slots.
    #[test]
    fn client_handshake() {
        let mut rng = thread_rng();

        for round in 0..64 {
            let hdr_len: usize = rng.gen_range(1..128);
            let raw = format!(
                "GET / HTTP/1.1\r\n{}\r\n",
                make_headers(round, hdr_len, TEMPLATE_HEADERS)
            );

            let mut storage = HttpHeader::new_custom_storage::<1024>();
            let mut request = Request::<1024>::new_custom_storage(&mut storage);
            let consumed = request.decode(raw.as_bytes()).unwrap();

            assert_eq!(consumed, raw.len());
            assert_eq!(request.path, b"/");
            assert_eq!(request.host, b"www.example.com");
            assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ==");

            // The non-required header must survive in `other_headers`.
            let has_accept = request.other_headers.iter().any(|hdr| {
                hdr.name == b"sec-websocket-accept"
                    && hdr.value == b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
            });
            assert!(has_accept);

            let mut out: Vec<u8> = vec![0; 0x4000];
            let written = request.encode(&mut out).unwrap();
            assert_eq!(written, consumed);
        }
    }

    /// Round-trips handshakes with caller-chosen host, path and key through
    /// `decode` and `encode`, using the default header storage.
    #[test]
    fn client_handshake2() {
        fn run(host: &str, path: &str, sec_key: &str) {
            let required = format!(
                "host: {0}\r\n\
                sec-websocket-key: {1}\r\n\
                upgrade: websocket\r\n\
                connection: upgrade\r\n\
                sec-websocket-version: 13",
                host, sec_key
            );
            let raw = format!(
                "GET {} HTTP/1.1\r\n{}\r\n",
                path,
                make_headers(16, 32, &required)
            );

            let mut storage = HttpHeader::new_storage();
            let mut request = Request::new_storage(&mut storage);
            let consumed = request.decode(raw.as_bytes()).unwrap();

            assert_eq!(consumed, raw.len());
            assert_eq!(request.host, host.as_bytes());
            assert_eq!(request.path, path.as_bytes());
            assert_eq!(request.sec_key, sec_key.as_bytes());

            let mut out: Vec<u8> = vec![0; 0x4000];
            let written = request.encode(&mut out).unwrap();
            assert_eq!(written, consumed);
        }

        run("host", "/path", "key");
        run("www.abc.com", "/path/to", "xxxxxx");
        run("wwww.www.ww.w", "/path/to/to/path", "xxxxxxyyyy");
    }
}