http_rs/
request.rs

1use std::str;
2use std::collections::HashMap;
3
4use bytes::BytesMut;
5use url::Url;
6
7use crate::{
8    Method,
9    HttpResult,
10};
11
12#[derive(Debug)]
13pub struct Request {
14    pub url: Url,
15    pub method: Method,
16    pub version: u8,
17    pub headers: HashMap<String, String>,
18}
19
20impl Request {
21    /// Create a new request
22    pub fn new(payload: &BytesMut, pos_n: &[usize]) -> Self {
23        let (method, url, version) = Self::_parse_request_line(&payload, &pos_n).unwrap();
24        let headers =  Self::_parse_headers(&payload, &pos_n);
25
26        Self {
27            url,
28            method,
29            version,
30            headers,
31        }
32    }
33
34    /// Check if http request a websocket upgrade
35    ///
36    /// [Client requirements][https://tools.ietf.org/html/rfc6455#section-1.2]
37    /// [Client handshake request](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers)
38    ///
39    /// Requirements:
40    /// ==
41    /// - **Version** HTTP/1.1
42    /// - **Method** GET
43    /// - **Headers** Upgrade, Connection
44    pub fn for_websocket(&self) -> bool {
45        if self.version != 1 || self.method != Method::GET {
46            return false
47        }
48
49        let required_headers = &[("upgrade", "websocket"), ("connection", "Upgrade")];
50
51        for (key, expected_value) in required_headers.iter() {
52            if let Some(value) = self.headers.get(*key) {
53                if value != expected_value {
54                    return false
55                }
56            } else {
57                return false
58            }
59        }
60
61        true
62    }
63
64    /// All header keys will be converted to lowercase.
65    #[inline(always)]
66    fn _parse_headers(payload: &BytesMut, pos_n: &[usize]) -> HashMap<String, String> {
67        let mut headers = HashMap::new();
68        let mut prev_idx = pos_n[0] + 1;
69
70        for idx in pos_n[1..].iter() {
71           let line = unsafe { String::from_utf8_unchecked(payload[prev_idx..*idx-1].to_vec()) };
72           if line.is_empty() { break }
73
74           let mut split = line.splitn(2, ':');
75           let key = split.next().unwrap();
76           let value = split.next().unwrap();
77
78           headers.insert(
79               key.to_string().trim().to_lowercase(),
80               value.trim().to_string()
81           );
82
83           prev_idx = idx + 1;
84        }
85
86        headers
87    }
88
89    /// Parse request line.
90    #[inline(always)]
91    fn _parse_request_line(payload: &BytesMut, pos_n: &[usize]) -> HttpResult<(Method, Url, u8)> {
92        let mut request_line = payload[..pos_n[0]-1].split(|c| *c == 0x20);
93
94        let method = match request_line.next().unwrap() {
95            b"GET" => Method::GET,
96            b"POST" => Method::POST,
97            _ => unimplemented!(),
98        };
99
100        let path = str::from_utf8(request_line.next().unwrap())?;
101        let url = Url::parse("https://hawk.local")?;
102        let url = url.join(&path)?;
103
104        let version = match request_line.next().unwrap() {
105            b"HTTP/1.1" => 1,
106            _ => unimplemented!(),
107        };
108
109        Ok((method, url, version))
110    }
111}
112
#[cfg(test)]
mod tests {
    extern crate test;

    use super::*;
    use bytes::BytesMut;
    use crate::Method;
    use test::Bencher;

    /// Build the sample request used by the parser tests.
    ///
    /// Returns the payload together with the offset of every `\n` in it.
    fn sample_request() -> (BytesMut, Vec<usize>) {
        let raw: &[u8] = b"GET /api/entity HTTP/1.1\r\nHost: hawk.local\r\n\r\n";
        let newlines = raw
            .iter()
            .enumerate()
            .filter(|&(_, &byte)| byte == b'\n')
            .map(|(offset, _)| offset)
            .collect();
        (BytesMut::from(raw), newlines)
    }

    #[test]
    fn for_websocket() {
        let headers: HashMap<String, String> =
            [("upgrade", "websocket"), ("connection", "Upgrade")]
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect();

        let mut request = Request {
            url: Url::parse("https://hawk.local").unwrap(),
            method: Method::GET,
            version: 1,
            headers,
        };
        assert!(request.for_websocket());

        // Dropping the Upgrade header turns it back into an ordinary GET.
        request.headers.remove("upgrade");
        assert!(!request.for_websocket());
    }

    #[test]
    fn parse_request_line() {
        let (payload, newlines) = sample_request();
        let parsed = Request::_parse_request_line(&payload, &newlines).unwrap();
        let (method, url, version) = parsed;

        assert_eq!(method, Method::GET);
        assert_eq!(url.as_str(), "https://hawk.local/api/entity");
        assert_eq!(version, 1);
    }

    #[test]
    fn parse_headers() {
        let (payload, newlines) = sample_request();
        let headers = Request::_parse_headers(&payload, &newlines);

        assert_eq!(headers.get("host").map(String::as_str), Some("hawk.local"));
    }

    /// Previously measured at about 611 ns/iter (+/- 26).
    #[bench]
    fn bench_parse_headers(b: &mut Bencher) {
        b.iter(parse_headers)
    }

    /// Previously measured at about 2,267 ns/iter (+/- 222).
    #[bench]
    fn bench_parse_request_line(b: &mut Bencher) {
        b.iter(parse_request_line)
    }
}