use cookie::Cookie;
use header::CONTENT_TYPE;
use mime::{APPLICATION_WWW_FORM_URLENCODED, Mime};
use serde::{Deserialize, de::DeserializeOwned};
use hyper::{body::Bytes, header};
use hyper::http::request::Parts as ReqParts;
use std::collections::HashMap;
use crate::security;
use error::*;
/// The concrete hyper request type accepted by [`adapt`].
type HyperRequest = hyper::http::Request<hyper::body::Body>;
/// A request body paired with the `Content-Type` header it was sent with.
///
/// The body has not been read yet; the `into_*` methods consume it.
#[derive(Debug)]
pub struct Body {
    /// The raw, not-yet-read hyper body stream.
    body: hyper::Body,
    /// Copy of the request's `Content-Type` header value, if one was present.
    content_type: Option<header::HeaderValue>,
}
/// Splits a hyper request into this module's [`Parts`] and [`Body`] wrappers.
///
/// The first `Content-Type` header value (if any) is copied into the returned
/// [`Body`], so body-parsing helpers can still check it after the request
/// head has been separated from the body. Cookies are not parsed here; the
/// cookie cache in [`Parts`] starts out empty.
pub fn adapt(req: HyperRequest) -> (Parts, Body) {
    let (parts, body) = req.into_parts();
    // Take the header copy before `parts` is moved into the `Parts` wrapper.
    let content_type = parts.headers.get(CONTENT_TYPE).cloned();
    (Parts { parts, cookies: None }, Body { body, content_type })
}
/// The request head (method, URI, headers, ...) plus a lazily-filled cookie cache.
#[derive(Debug)]
pub struct Parts {
    /// The underlying http request head.
    parts: ReqParts,
    /// Parsed cookies keyed by name; `None` until the cookies are first requested.
    cookies: Option<HashMap<String, Cookie<'static>>>,
}
/// Minimal view of a submitted urlencoded form, used only to extract the CSRF token.
///
/// No `deny_unknown_fields` is set, so all other form fields are ignored
/// when deserializing into this type.
#[derive(Deserialize)]
struct CsrfData {
    /// Value submitted under the `csrf` form key.
    csrf: String,
}
impl Parts {
    /// Returns the request's cookies keyed by name, parsing them on first use.
    ///
    /// Every `Cookie` header value is split on `;`, each trimmed piece is
    /// parsed with [`Cookie::parse_encoded`], and the result is cached so
    /// later calls do not re-parse.
    ///
    /// Parsing is lenient:
    /// - header values that are not valid UTF-8 are skipped;
    /// - pieces that fail to parse are skipped;
    /// - if a cookie name appears more than once, the last occurrence wins.
    pub fn cookies(&mut self) -> &HashMap<String, Cookie> {
        // Borrow the headers and the cache as disjoint fields so the closure can
        // read one while `get_or_insert_with` mutates the other.
        let headers = &self.parts.headers;
        self.cookies
            .get_or_insert_with(|| Self::parse_cookie_headers(headers))
    }

    /// Parses every `Cookie` header in `headers` into an owned name -> cookie map.
    fn parse_cookie_headers(
        headers: &hyper::HeaderMap<header::HeaderValue>,
    ) -> HashMap<String, Cookie<'static>> {
        headers
            .get_all(header::COOKIE)
            .into_iter()
            .filter_map(|value| std::str::from_utf8(value.as_bytes()).ok())
            .flat_map(|raw| raw.split(';'))
            .filter_map(|piece| Cookie::parse_encoded(piece.trim()).ok())
            // Collecting into a HashMap overwrites earlier entries, so the last duplicate wins.
            .map(|cookie| (cookie.name().to_string(), cookie.into_owned()))
            .collect()
    }

    /// The request method.
    pub fn method(&self) -> &hyper::Method {
        &self.parts.method
    }

    /// All request headers.
    pub fn headers(&self) -> &hyper::HeaderMap<header::HeaderValue> {
        &self.parts.headers
    }

    /// The request URI.
    pub fn uri(&self) -> &hyper::Uri {
        &self.parts.uri
    }

    /// Deserializes the URI query string into `T`.
    ///
    /// A URI without a query is treated as an empty query string.
    ///
    /// # Errors
    /// Returns [`QueryError`] if the query cannot be deserialized into `T`.
    pub fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError> {
        serde_urlencoded::from_str::<T>(self.parts.uri.query().unwrap_or("")).map_err(QueryError)
    }
}
impl Body {
    /// Reads the whole body into memory.
    ///
    /// # Errors
    /// Returns [`BodyError`] if hyper fails while reading the body stream.
    pub async fn into_bytes(self) -> Result<Bytes, BodyError> {
        // NOTE(review): no size limit is applied here; the entire body is buffered.
        // Confirm that a limit is enforced upstream for untrusted clients.
        hyper::body::to_bytes(self.body).await.map_err(BodyError)
    }

    /// Reads an `application/x-www-form-urlencoded` body and deserializes it into `T`.
    ///
    /// # Errors
    /// - [`FormError::ContentType`] if the request's media type is not urlencoded;
    /// - [`FormError::Body`] if reading the body fails;
    /// - [`FormError::Deserialize`] if the form does not match `T`.
    pub async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError> {
        self.enforce_content_type(APPLICATION_WWW_FORM_URLENCODED)?;
        let full_body = self.into_bytes().await?;
        serde_urlencoded::from_bytes::<T>(&full_body).map_err(FormError::Deserialize)
    }

    /// Like [`Body::into_form`], but first requires a `csrf` form field accepted by `csrf_token`.
    ///
    /// The body is deserialized twice:
    /// 1. into a struct holding only the `csrf` field, whose value is checked
    ///    with `csrf_token.matches`;
    /// 2. into `T`.
    ///
    /// # Errors
    /// - [`CsrfProtectedFormError::NoCsrf`] if no usable `csrf` field is present;
    /// - [`CsrfProtectedFormError::Csrf`] if the token is rejected;
    /// - otherwise the content-type, body, and deserialize errors of [`Body::into_form`].
    pub async fn into_form_csrf<T: DeserializeOwned>(
        self,
        csrf_token: &security::CsrfToken,
    ) -> Result<T, CsrfProtectedFormError> {
        self.enforce_content_type(APPLICATION_WWW_FORM_URLENCODED)?;
        let full_body = self.into_bytes().await?;
        let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body)
            .map_err(|_| CsrfProtectedFormError::NoCsrf)?;
        csrf_token.matches(csrf_data.csrf)?;
        serde_urlencoded::from_bytes::<T>(&full_body).map_err(CsrfProtectedFormError::Deserialize)
    }

    /// Checks that the request's media type matches `mime`.
    ///
    /// The comparison uses only the `type/subtype` part of the header, compared
    /// ASCII case-insensitively. Parameters such as `; charset=UTF-8` are
    /// ignored; browsers and HTTP clients commonly append them to urlencoded
    /// bodies.
    ///
    /// # Errors
    /// Returns [`WrongContentTypeError`] if the header is missing, is not
    /// visible ASCII, or names a different media type.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
        let received = self.content_type.as_ref().and_then(|h| h.to_str().ok());
        let matches = received
            .and_then(|value| value.split(';').next())
            .map_or(false, |essence| essence.trim().eq_ignore_ascii_case(mime.essence_str()));
        if matches {
            return Ok(());
        }
        Err(WrongContentTypeError { expected: mime, received: received.map(str::to_owned) })
    }
}
pub mod error {
use mime::Mime;
use thiserror::Error;
use hyper::StatusCode;
use crate::security::CsrfError;
#[derive(Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);
impl_into_error_simple!(QueryError, StatusCode::BAD_REQUEST);
#[derive(Error, Debug)]
#[error("failed to read body")]
pub struct BodyError(pub hyper::Error);
impl_into_error_simple!(BodyError, StatusCode::BAD_REQUEST);
#[derive(Error, Debug)]
#[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))]
pub struct WrongContentTypeError {
pub expected: Mime,
pub received: Option<String>,
}
#[derive(Error, Debug)]
pub enum FormError {
#[error("{0}")]
ContentType(#[from] WrongContentTypeError),
#[error("{0}")]
Body(#[from] BodyError),
#[error("form deserialize error: {0}")]
Deserialize(#[from] serde_urlencoded::de::Error),
}
impl_into_error_simple!(FormError, StatusCode::BAD_REQUEST);
#[derive(Error, Debug)]
pub enum CsrfProtectedFormError {
#[error("{0}")]
ContentType(#[from] WrongContentTypeError),
#[error("{0}")]
Body(#[from] BodyError),
#[error("form deserialize error: {0}")]
Deserialize(#[from] serde_urlencoded::de::Error),
#[error("no csrf token in form data")]
NoCsrf,
#[error("{0}")]
Csrf(#[from] CsrfError),
}
impl_into_error_simple!(CsrfProtectedFormError, StatusCode::BAD_REQUEST);
}