use std::{
collections::BTreeMap,
fmt::Display,
io::{self, Empty, Read},
};
use chrono::{DateTime, Local};
use contracts::*;
use pest::{iterators::Pair, Parser as PestParser};
use pest_derive::Parser as PestDeriveParser;
use tracing::{trace, warn};
use crate::errors::{ParsingError, Result, SevaError};
/// A parsed HTTP request.
///
/// The string fields borrow from the raw request text handed to the parser, so
/// a `Request` cannot outlive that text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request<'a> {
/// Method taken from the request line.
pub method: HttpMethod,
/// Request target exactly as it was matched in the request line.
pub path: &'a str,
/// Recognised headers keyed by [`HeaderName`]. Header names that
/// `HeaderName::from_str` does not know are skipped while parsing.
pub headers: BTreeMap<HeaderName, &'a str>,
/// Version text that follows `HTTP/` in the request line (for example `1.1`).
pub version: &'a str,
/// Local timestamp captured when the request value was constructed.
pub time: DateTime<Local>,
}
impl<'a> Request<'a> {
    /// Reports whether the request carries a `Range` header.
    pub fn is_partial(&self) -> bool {
        self.headers.get(&HeaderName::Range).is_some()
    }

    /// Parses raw request text into a [`Request`] that borrows from `req_str`.
    ///
    /// # Errors
    /// Returns [`ParsingError::PestRuleError`] when the text does not match the
    /// `request` grammar rule, or any error produced while converting the
    /// matched pair into a `Request`.
    pub fn parse(req_str: &str) -> Result<Request> {
        trace!("Request::parse");
        let outcome = HttpParser::parse(Rule::request, req_str);
        let mut pairs =
            outcome.map_err(|e| ParsingError::PestRuleError(format!("{e:?}")))?;
        // A successful parse of `Rule::request` is expected to yield that pair first.
        let top_level = pairs.next().unwrap();
        Request::try_from(top_level)
    }

    /// Collects the recognised headers found below a `headers` pair.
    ///
    /// Header names unknown to [`HeaderName::from_str`] are logged and skipped;
    /// a later occurrence of a known header replaces an earlier one.
    fn parse_headers(pair: Pair<'a, Rule>) -> Result<BTreeMap<HeaderName, &'a str>> {
        trace!("Request::parse_headers");
        let mut collected = BTreeMap::new();
        for header in pair.into_inner() {
            let mut fields = header.into_inner();
            let hdr_name_opt = fields.next().unwrap().as_str();
            match HeaderName::from_str(hdr_name_opt) {
                Some(known) => {
                    // The value pair is only consumed for headers we keep.
                    let raw_value = fields.next().unwrap().as_str();
                    collected.insert(known, raw_value);
                }
                None => warn!("ignored unknown header: {hdr_name_opt}"),
            }
        }
        Ok(collected)
    }
}
impl<'i> TryFrom<Pair<'i, Rule>> for Request<'i> {
type Error = SevaError;
fn try_from(
pair: Pair<'i, Rule>,
) -> std::prelude::v1::Result<Self, Self::Error> {
let mut iterator = pair.into_inner();
let method = iterator.next().unwrap().try_into()?;
let path = iterator.next().unwrap().as_str();
let version = iterator.next().unwrap().as_str();
let headers = match iterator.next() {
Some(rule) => Request::parse_headers(rule)?,
None => BTreeMap::new(),
};
let req = Self {
method,
path,
version,
headers,
time: Local::now(),
};
Ok(req)
}
}
/// An HTTP response made of a status, a header map and a readable body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Response<B>
where
B: Read,
{
/// Status reported to the client.
pub status: StatusCode,
/// Response headers with owned values, keyed by [`HeaderName`].
pub headers: BTreeMap<HeaderName, String>,
/// Body source; any [`Read`] implementation.
pub body: B,
}
impl<B> Response<B>
where
B: Read,
{
pub fn new(
status: StatusCode,
headers: BTreeMap<HeaderName, String>,
body: B,
) -> Response<B> {
Self {
status,
headers,
body,
}
}
}
/// Step-by-step builder for a [`Response`].
///
/// Holds the same three parts as the response it produces; `build` hands them
/// over unchanged.
#[derive(Debug, PartialEq)]
pub struct ResponseBuilder<B> {
// Status the built response will carry.
status: StatusCode,
// Headers accumulated so far.
headers: BTreeMap<HeaderName, String>,
// Body the built response will carry.
body: B,
}
impl ResponseBuilder<Empty> {
    /// Starts a builder with the given status and headers and an empty body.
    pub fn new(
        status: StatusCode,
        headers: BTreeMap<HeaderName, String>,
    ) -> ResponseBuilder<Empty> {
        let body = io::empty();
        ResponseBuilder {
            status,
            headers,
            body,
        }
    }

    /// Builder preset: `StatusCode::Ok`, no headers.
    pub fn ok() -> ResponseBuilder<Empty> {
        let headers = BTreeMap::new();
        Self::new(StatusCode::Ok, headers)
    }

    /// Builder preset: `StatusCode::PartialContent`, no headers.
    pub fn partial() -> ResponseBuilder<Empty> {
        let headers = BTreeMap::new();
        Self::new(StatusCode::PartialContent, headers)
    }

    /// Builder preset: `StatusCode::NotFound`, no headers.
    pub fn not_found() -> ResponseBuilder<Empty> {
        let headers = BTreeMap::new();
        Self::new(StatusCode::NotFound, headers)
    }

    /// Builder preset: `StatusCode::MethodNotAllowed`, no headers.
    pub fn method_not_allowed() -> ResponseBuilder<Empty> {
        let headers = BTreeMap::new();
        Self::new(StatusCode::MethodNotAllowed, headers)
    }

    /// Builder preset: `StatusCode::MovedPermanently` with a single `Location`
    /// header set to `location`.
    #[debug_ensures(ret.headers.len() == 1)]
    pub fn redirect(location: &str) -> ResponseBuilder<Empty> {
        let target = location.to_owned();
        let mut only_location = BTreeMap::new();
        only_location.insert(HeaderName::Location, target);
        Self::new(StatusCode::MovedPermanently, only_location)
    }

    /// Produces a new builder carrying `body`.
    ///
    /// The status is copied and the headers are cloned, so `self` stays usable.
    pub fn body<B: Read>(&self, body: B) -> ResponseBuilder<B> {
        let status = self.status;
        let headers = self.headers.clone();
        ResponseBuilder {
            status,
            headers,
            body,
        }
    }
}
impl<B> ResponseBuilder<B>
where
B: Read,
{
#[allow(unused)]
pub fn header(&mut self, name: HeaderName, val: &str) -> &mut Self {
self.headers.insert(name, val.to_owned());
self
}
pub fn headers(
&mut self,
hdrs: impl IntoIterator<Item = (HeaderName, String)>,
) -> &mut Self {
self.headers.extend(hdrs);
self
}
#[allow(unused)]
pub fn status(&mut self, status: StatusCode) -> &mut Self {
self.status = status;
self
}
pub fn build(self) -> Response<B> {
Response::new(self.status, self.headers, self.body)
}
}
/// HTTP request methods this module can parse and print.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum HttpMethod {
/// `CONNECT`
Connect,
/// `DELETE`
Delete,
/// `GET`
Get,
/// `HEAD`
Head,
/// `OPTIONS`
Options,
/// `PATCH`
Patch,
/// `POST`
Post,
/// `PUT`
Put,
/// `TRACE`
Trace,
}
impl<'i> TryFrom<Pair<'i, Rule>> for HttpMethod {
type Error = ParsingError;
fn try_from(
value: Pair<'i, Rule>,
) -> std::prelude::v1::Result<Self, Self::Error> {
Self::try_from(value.as_str().as_bytes())
}
}
impl TryFrom<&[u8]> for HttpMethod {
    type Error = ParsingError;

    /// Converts an upper-case method token such as `b"GET"` into an
    /// [`HttpMethod`]. Matching is exact and case-sensitive.
    ///
    /// # Errors
    /// Returns [`ParsingError::UnknownMethod`] carrying the rejected token.
    /// Bytes that are not valid UTF-8 are rendered with U+FFFD replacement
    /// characters so the offending input stays visible in the error instead of
    /// collapsing to an empty string.
    fn try_from(value: &[u8]) -> std::prelude::v1::Result<Self, Self::Error> {
        let method = match value {
            b"CONNECT" => HttpMethod::Connect,
            b"DELETE" => HttpMethod::Delete,
            b"GET" => HttpMethod::Get,
            b"HEAD" => HttpMethod::Head,
            b"OPTIONS" => HttpMethod::Options,
            b"PATCH" => HttpMethod::Patch,
            b"POST" => HttpMethod::Post,
            b"PUT" => HttpMethod::Put,
            b"TRACE" => HttpMethod::Trace,
            unknown => {
                let shown = String::from_utf8_lossy(unknown).into_owned();
                return Err(ParsingError::UnknownMethod(shown));
            }
        };
        Ok(method)
    }
}
impl From<HttpMethod> for String {
    /// Renders the method as its upper-case request-line token (for example
    /// `"GET"`).
    ///
    /// The tokens are stored upper-case already, so a single allocation is
    /// enough; no runtime case conversion is needed.
    fn from(value: HttpMethod) -> Self {
        let token: &'static str = match value {
            HttpMethod::Connect => "CONNECT",
            HttpMethod::Delete => "DELETE",
            HttpMethod::Get => "GET",
            HttpMethod::Head => "HEAD",
            HttpMethod::Options => "OPTIONS",
            HttpMethod::Patch => "PATCH",
            HttpMethod::Post => "POST",
            HttpMethod::Put => "PUT",
            HttpMethod::Trace => "TRACE",
        };
        token.to_owned()
    }
}
impl Display for HttpMethod {
    /// Writes the method's string form (see `From<HttpMethod> for String`)
    /// straight to the formatter; width and fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered: String = (*self).into();
        f.write_str(rendered.as_str())
    }
}
/// Pest parser generated from the inline grammar below.
///
/// The grammar has three independent entry points:
/// * `request` — a request line, an optional header block and the terminating
///   empty line;
/// * `accept_encoding` — the value of an `Accept-Encoding` header;
/// * `bytes_range` — the value of a `Range` header using the `bytes` unit.
///
/// In `header`, any amount of spaces or tabs (including none) may follow the
/// colon and the value may be empty, so `Host:example.org` and `Host:  x` are
/// both accepted and the captured value never starts with that padding.
#[derive(PestDeriveParser)]
#[grammar_inline = r#"
request = { request_line ~ headers? ~ NEWLINE }
request_line = _{ method ~ " "+ ~ uri ~ " "+ ~ "HTTP/" ~ version ~ NEWLINE }
uri = { (!whitespace ~ ANY)+ }
method = { ("CONNECT" | "DELETE" | "GET" | "HEAD" | "OPTIONS" | "PATCH" | "POST" | "PUT" | "TRACE") }
version = { (ASCII_DIGIT | ".")+ }
whitespace = _{ " " | "\t" }
headers = { header+ }
// optional whitespace after the colon; the value itself may be empty
header = { header_name ~ ":" ~ ws ~ header_value ~ NEWLINE }
header_name = { (!(NEWLINE | ":") ~ ANY)+ }
header_value = { (!NEWLINE ~ ANY)* }
// accept-encoding header parser
ws = _{( " " | "\t")*}
accept_encoding = { encoding ~ ws ~ ("," ~ ws ~ encoding)* ~ EOI}
algo = {(ASCII_ALPHA+ | "identity" | "*")}
weight = {ws ~ ";" ~ ws ~ "q=" ~ qvalue}
qvalue = { ("0" ~ ("." ~ ASCII_DIGIT{,3}){,1}) | ("1" ~ ("." ~ "0"{,3}){,1}) }
encoding = { algo ~ weight*}
// Range header parser
//
// A range request can specify a single range or a set of ranges within a single representation.
//
// Range = ranges-specifier
// ranges-specifier = range-unit "=" range-set
// range-unit = token
// range-set = 1#range-spec
// range-spec = int-range / suffix-range / other-range
// int-range = first-pos "-" [ last-pos ]
// first-pos = 1*DIGIT
// last-pos = 1*DIGIT
// suffix-range = "-" suffix-length
// suffix-length = 1*DIGIT
// other-range = 1*( %x21-2B / %x2D-7E ) ; 1*(VCHAR excluding comma)
//
bytes_range = { "bytes" ~ ws ~ "=" ~ ws ~ range_sets }
range_sets = _{ range_set ~ ws ~ ("," ~ ws ~ range_set)* ~ EOI }
range_set = _{(int_range | suffix_range)}
int_range = { first_pos ~ "-" ~ last_pos*}
suffix_range = { "-" ~ len}
first_pos = { ASCII_DIGIT+ }
last_pos = { ASCII_DIGIT+ }
len = { ASCII_DIGIT* }
"#]
pub struct HttpParser;
impl HttpParser {
    /// Parses the value of a `Range` header (`bytes=...`) against a resource
    /// that is `max_len` bytes long.
    ///
    /// Every returned [`BytesRange`] has `start < max_len`, `size >= 1` and
    /// `start + size <= max_len`:
    /// * `first-last` — `last` is inclusive, so `0-499` is 500 bytes; a `last`
    ///   beyond the end is clamped to the final byte;
    /// * `first-` — runs from `first` to the end of the resource;
    /// * `-n` — the final `n` bytes, or the whole resource when `n` exceeds it.
    ///
    /// # Errors
    /// * [`ParsingError::PestRuleError`] when `val` does not match the
    ///   `bytes_range` grammar rule.
    /// * [`ParsingError::IntError`] when a position does not fit in `usize`
    ///   (or the suffix length is missing).
    /// * [`ParsingError::InvalidRangeHeader`] when a range is reversed, starts
    ///   at or beyond `max_len`, selects zero bytes, or when more than ten
    ///   ranges are given.
    pub fn parse_bytes_range(val: &str, max_len: usize) -> Result<Vec<BytesRange>> {
        // Upper bound on the number of ranges accepted from a single header.
        const MAX_RANGES: usize = 10;

        let invalid = || ParsingError::InvalidRangeHeader(val.to_owned());

        let br = HttpParser::parse(Rule::bytes_range, val)
            .map_err(|e| ParsingError::PestRuleError(format!("{e:?}")))?
            .next()
            .unwrap();

        let mut ranges = Vec::new();
        for pair in br.into_inner() {
            match pair.as_rule() {
                Rule::int_range => {
                    let mut inner = pair.into_inner();
                    let start = inner
                        .next()
                        .unwrap()
                        .as_str()
                        .parse::<usize>()
                        .map_err(ParsingError::IntError)?;
                    // `end` is exclusive: one past the last selected byte.
                    let end = match inner.next() {
                        Some(last) => {
                            let last = last
                                .as_str()
                                .parse::<usize>()
                                .map_err(ParsingError::IntError)?;
                            // last-pos is inclusive; clamp it to the resource.
                            last.saturating_add(1).min(max_len)
                        }
                        None => max_len,
                    };
                    // Covers both `last < first` and `first` at/after the end.
                    if start >= end {
                        return Err(invalid().into());
                    }
                    ranges.push(BytesRange {
                        start,
                        size: end - start,
                    });
                }
                Rule::suffix_range => {
                    let mut inner = pair.into_inner();
                    let requested = inner
                        .next()
                        .unwrap()
                        .as_str()
                        .parse::<usize>()
                        .map_err(ParsingError::IntError)?;
                    // A suffix longer than the resource selects all of it.
                    let size = requested.min(max_len);
                    if size == 0 {
                        return Err(invalid().into());
                    }
                    ranges.push(BytesRange {
                        start: max_len - size,
                        size,
                    });
                }
                // Anything else the grammar emits here (such as the
                // end-of-input marker) carries no range data.
                _ => {}
            }
        }

        if ranges.len() > MAX_RANGES {
            return Err(invalid().into());
        }
        Ok(ranges)
    }
}
/// One byte range resolved by `HttpParser::parse_bytes_range`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct BytesRange {
/// Offset of the first byte of the range.
pub start: usize,
/// Length associated with the range.
/// NOTE(review): presumably a byte count starting at `start` — confirm
/// against the code that consumes these ranges.
pub size: usize,
}
macro_rules! status_codes {
(
$(
$(#[$docs:meta])+
($name:ident, $code:literal);
)+
) => {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Copy)]
#[allow(unused)]
pub enum StatusCode {
$(
$(#[$docs])*
$name,
)+
}
impl StatusCode {
pub fn as_u16(&self) -> u16 {
match *self {
$(
StatusCode::$name => $code,
)+
}
}
fn as_string(&self) -> String {
match *self {
$(
StatusCode::$name => Self::split_name(stringify!($name)),
)+
}
}
fn split_name(name:&str) -> String {
let mut parts = vec!();
let mut cur = String::new();
for ch in name.chars() {
if ch.is_uppercase() && !cur.is_empty() {
parts.push(cur.clone());
cur.clear();
}
cur.push(ch);
}
parts.push(cur);
parts.join(" ").to_uppercase()
}
}
impl std::fmt::Display for StatusCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_string())
}
}
};
}
status_codes! {
/// 101 Switching Protocols
(SwitchingProtocols,101);
/// 200 OK
(Ok, 200);
/// 204 No Content
(NoContent,204);
/// 206 Partial Content
(PartialContent,206);
/// 301 Moved Permanently
(MovedPermanently,301);
/// 304 Not Modified
(NotModified,304);
/// 400 Bad Request
(BadRequest,400);
/// 403 Forbidden
(Forbidden,403);
/// 404 Not Found
(NotFound,404);
/// 405 Method Not Allowed
(MethodNotAllowed,405);
/// 413 Payload Too Large
(PayloadTooLarge,413);
/// 414 URI Too Long
(UriTooLong, 414);
/// 408 Request Timeout
(RequestTimeout,408);
/// 416 Range Not Satisfiable
///
/// NOTE(review): the identifier is spelled `Satisifiable`; the generated
/// reason phrase is derived from it and inherits the spelling.
(RangeNotSatisifiable, 416);
/// 429 Too Many Requests
(TooManyRequests,429);
/// 500 Internal Server Error
(InternalServerError,500);
/// 501 Not Implemented
(NotImplemented, 501);
/// 505 HTTP Version Not Supported
(HttpVersionNotSupported, 505);
/// 510 Not Extended
(NotExtended,510);
}
/// Generates the `HeaderName` enum from a list of `(Name, "wire-name");`
/// entries.
///
/// Each entry may carry any number of attributes (typically a doc comment),
/// including none; they are forwarded to the generated variant. Besides the
/// enum the macro emits `as_str`, a case-insensitive `from_str` lookup and a
/// `Display` impl that prints the wire name.
macro_rules! header_names {
    (
        $(
            $(#[$docs:meta])*
            ($hname:ident, $name_str:literal);
        )+
    ) => {
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Copy)]
        pub enum HeaderName {
            $(
                $(#[$docs])*
                $hname,
            )+
        }

        impl HeaderName {
            /// Returns the lower-case wire name of the header.
            pub fn as_str(&self) -> &'static str {
                match *self {
                    $(
                        HeaderName::$hname => $name_str,
                    )+
                }
            }

            /// Looks up a header by name, ignoring ASCII case and surrounding
            /// whitespace. Returns `None` for names that are not listed.
            ///
            /// The comparison is done in place, without building a lower-cased
            /// copy of `s`.
            pub fn from_str(s: &str) -> Option<HeaderName> {
                let candidate = s.trim();
                $(
                    if candidate.eq_ignore_ascii_case($name_str) {
                        return Some(HeaderName::$hname);
                    }
                )+
                None
            }
        }

        impl std::fmt::Display for HeaderName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.as_str())
            }
        }
    };
}
header_names! {
/// `Accept`
(Accept, "accept");
/// `Accept-Encoding`
(AcceptEncoding, "accept-encoding");
/// `Accept-Language`
(AcceptLanguage, "accept-language");
/// `Accept-Ranges`
(AcceptRanges, "accept-ranges");
/// `Allow`
(Allow, "allow");
/// `Cache-Control`
(CacheControl, "cache-control");
/// `Connection`
(Connection, "connection");
/// `Content-Disposition`
(ContentDisposition, "content-disposition");
/// `Content-Encoding`
(ContentEncoding, "content-encoding");
/// `Content-Length`
(ContentLength, "content-length");
/// `Content-Range`
(ContentRange, "content-range");
/// `Content-Type`
(ContentType, "content-type");
/// `Date`
(Date, "date");
/// `ETag`
(ETag, "etag");
/// `Host`
(Host, "host");
/// `If-Modified-Since`
(IfModifiedSince, "if-modified-since");
/// `If-Unmodified-Since`
(IfUnmodifiedSince, "if-unmodified-since");
/// `Last-Modified`
(LastModified, "last-modified");
/// `Location`
(Location, "location");
/// `Range`
(Range, "range");
/// `Referer`
(Referer, "referer");
/// `Server`
(Server, "server");
/// `User-Agent`
(UserAgent, "user-agent");
/// `Vary`
(Vary, "vary");
/// `Warning`
(Warning, "warning");
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use maplit::btreemap;

    use super::*;

    /// A minimal GET request yields the expected method, target, version and
    /// headers. The timestamp is not compared because it is taken at parse time.
    #[test]
    fn request_parsing() -> Result<()> {
        let req_str =
            "GET / HTTP/1.1\r\nHost: developer.mozilla.org\r\nAccept-Language: fr\r\n\r\n";
        let expected = Request {
            method: HttpMethod::Get,
            path: "/",
            headers: btreemap! {
                HeaderName::AcceptLanguage => "fr",
                HeaderName::Host => "developer.mozilla.org",
            },
            version: "1.1",
            time: Local::now(),
        };

        let parsed: Request = Request::parse(req_str)?;

        assert_eq!(parsed.method, expected.method);
        assert_eq!(parsed.path, expected.path);
        assert_eq!(parsed.version, expected.version);
        assert_eq!(parsed.headers, expected.headers);
        Ok(())
    }

    /// The `accept_encoding` rule accepts a weighted list of encodings.
    #[test]
    fn accept_encoding_parser() -> Result<()> {
        let val = "compress;q=0.5, gzip";
        let outcome = HttpParser::parse(Rule::accept_encoding, val);
        assert!(outcome.is_ok());
        Ok(())
    }

    /// Every sample `Range` value parses against a 10000-byte resource.
    #[test]
    fn bytes_range_parser() -> Result<()> {
        let samples = [
            "bytes=0-499",
            "bytes=500-999",
            "bytes=-500",
            "bytes=9500-",
            "bytes=0-0,-1",
            "bytes= 0-0, -2",
            "bytes= 0-999, 4500-5499, -1000",
            "bytes=500-600,601-999",
            "bytes=500-700,601-999",
        ];
        for val in samples {
            let range = HttpParser::parse_bytes_range(val, 10000);
            assert!(range.is_ok(), "failed to parse: {val}. Reason: {range:?}");
        }
        Ok(())
    }

    /// Attaching a body changes the builder's body type and keeps the status
    /// and headers.
    #[test]
    fn response_body_type_mapping() -> Result<()> {
        let expected = ResponseBuilder {
            status: StatusCode::Ok,
            headers: BTreeMap::new(),
            body: Cursor::new(vec![]),
        };

        let empty_builder = ResponseBuilder::ok();
        let with_body = empty_builder.body(Cursor::new(vec![]));

        assert_eq!(with_body, expected);
        Ok(())
    }
}