1use std::str;
2use std::collections::HashMap;
3
4use bytes::BytesMut;
5use url::Url;
6
7use crate::{
8 Method,
9 HttpResult,
10};
11
12#[derive(Debug)]
13pub struct Request {
14 pub url: Url,
15 pub method: Method,
16 pub version: u8,
17 pub headers: HashMap<String, String>,
18}
19
20impl Request {
21 pub fn new(payload: &BytesMut, pos_n: &[usize]) -> Self {
23 let (method, url, version) = Self::_parse_request_line(&payload, &pos_n).unwrap();
24 let headers = Self::_parse_headers(&payload, &pos_n);
25
26 Self {
27 url,
28 method,
29 version,
30 headers,
31 }
32 }
33
34 pub fn for_websocket(&self) -> bool {
45 if self.version != 1 || self.method != Method::GET {
46 return false
47 }
48
49 let required_headers = &[("upgrade", "websocket"), ("connection", "Upgrade")];
50
51 for (key, expected_value) in required_headers.iter() {
52 if let Some(value) = self.headers.get(*key) {
53 if value != expected_value {
54 return false
55 }
56 } else {
57 return false
58 }
59 }
60
61 true
62 }
63
64 #[inline(always)]
66 fn _parse_headers(payload: &BytesMut, pos_n: &[usize]) -> HashMap<String, String> {
67 let mut headers = HashMap::new();
68 let mut prev_idx = pos_n[0] + 1;
69
70 for idx in pos_n[1..].iter() {
71 let line = unsafe { String::from_utf8_unchecked(payload[prev_idx..*idx-1].to_vec()) };
72 if line.is_empty() { break }
73
74 let mut split = line.splitn(2, ':');
75 let key = split.next().unwrap();
76 let value = split.next().unwrap();
77
78 headers.insert(
79 key.to_string().trim().to_lowercase(),
80 value.trim().to_string()
81 );
82
83 prev_idx = idx + 1;
84 }
85
86 headers
87 }
88
89 #[inline(always)]
91 fn _parse_request_line(payload: &BytesMut, pos_n: &[usize]) -> HttpResult<(Method, Url, u8)> {
92 let mut request_line = payload[..pos_n[0]-1].split(|c| *c == 0x20);
93
94 let method = match request_line.next().unwrap() {
95 b"GET" => Method::GET,
96 b"POST" => Method::POST,
97 _ => unimplemented!(),
98 };
99
100 let path = str::from_utf8(request_line.next().unwrap())?;
101 let url = Url::parse("https://hawk.local")?;
102 let url = url.join(&path)?;
103
104 let version = match request_line.next().unwrap() {
105 b"HTTP/1.1" => 1,
106 _ => unimplemented!(),
107 };
108
109 Ok((method, url, version))
110 }
111}
112
#[cfg(test)]
mod tests {
    extern crate test;

    use super::*;
    use crate::Method;
    use bytes::BytesMut;
    use test::Bencher;

    /// Minimal request head shared by the parsing tests.
    const RAW_REQUEST: &[u8] = b"GET /api/entity HTTP/1.1\r\nHost: hawk.local\r\n\r\n";

    /// Byte offsets of every `\n` in `RAW_REQUEST`.
    const LINE_FEEDS: [usize; 3] = [25, 43, 45];

    /// Builds a fresh buffer holding `RAW_REQUEST`.
    fn raw_payload() -> BytesMut {
        BytesMut::from(RAW_REQUEST)
    }

    #[test]
    fn for_websocket() {
        let headers: HashMap<String, String> =
            [("upgrade", "websocket"), ("connection", "Upgrade")]
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect();

        let mut request = Request {
            url: Url::parse("https://hawk.local").unwrap(),
            method: Method::GET,
            version: 1,
            headers,
        };
        assert!(request.for_websocket());

        // Without the `Upgrade` header the request no longer qualifies.
        request.headers.remove("upgrade");
        assert!(!request.for_websocket());
    }

    #[test]
    fn parse_request_line() {
        let payload = raw_payload();
        let (method, url, version) =
            Request::_parse_request_line(&payload, &LINE_FEEDS).unwrap();

        assert_eq!(method, Method::GET);
        assert_eq!(url.as_str(), "https://hawk.local/api/entity");
        assert_eq!(version, 1);
    }

    #[test]
    fn parse_headers() {
        let payload = raw_payload();
        let headers = Request::_parse_headers(&payload, &LINE_FEEDS);

        assert_eq!(headers.get("host").map(String::as_str), Some("hawk.local"));
    }

    #[bench]
    fn bench_parse_headers(b: &mut Bencher) {
        b.iter(parse_headers);
    }

    #[bench]
    fn bench_parse_request_line(b: &mut Bencher) {
        b.iter(parse_request_line);
    }
}