use std::str;
use std::collections::HashMap;
use bytes::BytesMut;
use url::Url;
use crate::{
Method,
HttpResult,
};
/// A parsed HTTP request head: request line plus headers.
#[derive(Debug)]
pub struct Request {
/// Request target as an absolute URL. When built by the parser in this
/// file, the path is joined onto the fixed base `https://hawk.local`.
pub url: Url,
/// HTTP method taken from the request line.
pub method: Method,
/// HTTP version marker; the parser in this file produces `1` for `HTTP/1.1`.
pub version: u8,
/// Header fields. The parser in this file stores names trimmed and
/// lowercased, and values trimmed.
pub headers: HashMap<String, String>,
}
impl Request {
    /// Builds a [`Request`] from a raw HTTP payload.
    ///
    /// `pos_n` is expected to hold the byte offsets of the `\n` characters in
    /// `payload` in ascending order (as in this file's unit tests); its first
    /// entry marks the end of the request line.
    ///
    /// # Panics
    ///
    /// Panics if the request line cannot be parsed, because the result of
    /// `_parse_request_line` is unwrapped.
    pub fn new(payload: &BytesMut, pos_n: &[usize]) -> Self {
        // Both arguments are already references; forward them unchanged.
        let (method, url, version) = Self::_parse_request_line(payload, pos_n).unwrap();
        let headers = Self::_parse_headers(payload, pos_n);
        Self {
            url,
            method,
            version,
            headers,
        }
    }

    /// Returns `true` when this request asks for a WebSocket upgrade.
    ///
    /// The request must be an HTTP/1.1 `GET` whose `upgrade` header contains
    /// the `websocket` token and whose `connection` header contains the
    /// `upgrade` token. Header values are treated as comma-separated token
    /// lists and compared ASCII case-insensitively, so values such as
    /// `Connection: keep-alive, Upgrade` or `Upgrade: WebSocket` are accepted.
    pub fn for_websocket(&self) -> bool {
        // (lowercased header name, token that must appear in its value)
        const REQUIRED_TOKENS: [(&str, &str); 2] =
            [("upgrade", "websocket"), ("connection", "upgrade")];

        if self.version != 1 || self.method != Method::GET {
            return false;
        }

        REQUIRED_TOKENS.iter().all(|&(name, token)| {
            self.headers.get(name).map_or(false, |value| {
                value
                    .split(',')
                    .any(|candidate| candidate.trim().eq_ignore_ascii_case(token))
            })
        })
    }

    /// Parses the header lines that follow the request line.
    ///
    /// Each line spans from one byte past the previous `pos_n` offset up to
    /// (but excluding) the byte before the next offset, i.e. the trailing
    /// `\r\n` is dropped. Names are trimmed and lowercased, values trimmed.
    ///
    /// Parsing stops at the first empty line, or as soon as the offsets do not
    /// describe a valid range inside `payload`. Lines that are not valid UTF-8
    /// or that lack a `:` separator are skipped rather than causing a panic.
    #[inline(always)]
    fn _parse_headers(payload: &BytesMut, pos_n: &[usize]) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        let bytes: &[u8] = &payload[..];

        let (first, rest) = match pos_n.split_first() {
            Some((&first, rest)) => (first, rest),
            None => return headers,
        };

        let mut line_start = first.saturating_add(1);
        for &newline in rest {
            // Exclude the `\r` that precedes the `\n`; bail out on offsets
            // that would underflow or fall outside the payload.
            let raw = match newline
                .checked_sub(1)
                .and_then(|line_end| bytes.get(line_start..line_end))
            {
                Some(raw) => raw,
                None => break,
            };
            // `newline` is within the payload here, so this cannot overflow.
            line_start = newline + 1;

            // An empty line terminates the header section.
            if raw.is_empty() {
                break;
            }

            // The payload is untrusted: validate UTF-8 instead of assuming it.
            let line = match str::from_utf8(raw) {
                Ok(line) => line,
                Err(_) => continue,
            };

            let mut parts = line.splitn(2, ':');
            if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
                headers.insert(key.trim().to_lowercase(), value.trim().to_string());
            }
        }
        headers
    }

    /// Parses the request line (`METHOD SP target SP version`).
    ///
    /// Returns the method, the target joined onto the base URL
    /// `https://hawk.local`, and the version marker (`1` for `HTTP/1.1`).
    ///
    /// # Errors
    ///
    /// Returns an error when the target is not valid UTF-8 or cannot be
    /// joined onto the base URL.
    ///
    /// # Panics
    ///
    /// Panics when `pos_n` is empty or its first offset is `0` or lies outside
    /// `payload`, when the line has fewer than three space-separated parts, or
    /// when the method is not `GET`/`POST` or the version is not `HTTP/1.1`.
    #[inline(always)]
    fn _parse_request_line(payload: &BytesMut, pos_n: &[usize]) -> HttpResult<(Method, Url, u8)> {
        // Drop the `\r` that precedes the first `\n`, then split on spaces.
        let mut request_line = payload[..pos_n[0] - 1].split(|c| *c == 0x20);

        let method = match request_line.next().unwrap() {
            b"GET" => Method::GET,
            b"POST" => Method::POST,
            _ => unimplemented!(),
        };

        let path = str::from_utf8(request_line.next().unwrap())?;
        let url = Url::parse("https://hawk.local")?;
        let url = url.join(path)?;

        let version = match request_line.next().unwrap() {
            b"HTTP/1.1" => 1,
            _ => unimplemented!(),
        };

        Ok((method, url, version))
    }
}
#[cfg(test)]
mod tests {
    extern crate test;

    use super::*;
    use bytes::BytesMut;
    use crate::Method;
    use test::Bencher;

    /// Sample request shared by the parser tests: a request line, a single
    /// header and the terminating blank line.
    const RAW_REQUEST: &[u8] = b"GET /api/entity HTTP/1.1\r\nHost: hawk.local\r\n\r\n";

    /// Offsets of each `\n` byte inside `RAW_REQUEST`.
    const NEWLINE_OFFSETS: [usize; 3] = [25, 43, 45];

    /// A request with both upgrade headers is a WebSocket handshake; removing
    /// the `upgrade` header makes it an ordinary request.
    #[test]
    fn for_websocket() {
        let headers: HashMap<String, String> =
            [("upgrade", "websocket"), ("connection", "Upgrade")]
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect();
        let mut request = Request {
            url: Url::parse("https://hawk.local").unwrap(),
            method: Method::GET,
            version: 1,
            headers,
        };

        assert!(request.for_websocket());

        request.headers.remove("upgrade");
        assert!(!request.for_websocket());
    }

    /// The request line yields the method, the absolute URL and the version.
    #[test]
    fn parse_request_line() {
        let payload = BytesMut::from(RAW_REQUEST);
        let parsed = Request::_parse_request_line(&payload, &NEWLINE_OFFSETS);
        let (method, url, version) = parsed.unwrap();

        assert_eq!(method, Method::GET);
        assert_eq!(url.as_str(), "https://hawk.local/api/entity");
        assert_eq!(version, 1);
    }

    /// Header names are stored lowercased and values trimmed.
    #[test]
    fn parse_headers() {
        let payload = BytesMut::from(RAW_REQUEST);
        let parsed = Request::_parse_headers(&payload, &NEWLINE_OFFSETS);
        let host = parsed.get("host").unwrap();

        assert_eq!("hawk.local", host);
    }

    #[bench]
    fn bench_parse_headers(b: &mut Bencher) {
        b.iter(|| parse_headers())
    }

    #[bench]
    fn bench_parse_request_line(b: &mut Bencher) {
        b.iter(|| parse_request_line())
    }
}