use cookie::Cookie;
use mime::Mime;
use serde::{Deserialize, de::DeserializeOwned};
use hyper::{HeaderMap, body::Bytes, header, http::request::Parts};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
use rand::{Rng, distributions::Alphanumeric};
use async_trait::async_trait;
use crate::response::{SputnikHeaders, delete_cookie};
/// Name of the cookie that stores the CSRF token (read by `csrf_token_from_cookies`,
/// set by `CsrfToken::from_request` when no token exists yet).
const CSRF_COOKIE_NAME : &str = "csrf";
/// Convenience helpers on the request head (`hyper::http::request::Parts`).
pub trait SputnikParts {
    /// Deserializes the URL query string into `T`; a request without a query string is
    /// treated as an empty query.
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError>;

    /// Returns all request cookies keyed by name.
    ///
    /// The map is parsed from the `Cookie` headers on first use and cached in the request
    /// extensions, so later calls return the same shared map. Non-UTF-8 header values and
    /// unparsable cookie pairs are skipped; if a name repeats, the last occurrence wins.
    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;

    /// Succeeds only if the `Content-Type` header equals `mime.to_string()` byte-for-byte.
    ///
    /// A header carrying extra parameters (e.g. `; charset=utf-8`) is rejected unless `mime`
    /// carries exactly the same parameters.
    ///
    /// # Errors
    /// [`WrongContentTypeError`] with the expected type and whatever header value was received.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;

    /// Mutable access to a `HeaderMap` kept in the request extensions, created empty on first
    /// access. This module queues `Set-Cookie` headers here (CSRF and flash cookies).
    fn response_headers(&mut self) -> &mut HeaderMap;
}
impl SputnikParts for hyper::http::request::Parts {
    /// Decodes `uri.query()` with `serde_urlencoded`; a missing query decodes from `""`.
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError> {
        let raw_query = self.uri.query().unwrap_or_default();
        serde_urlencoded::from_str::<T>(raw_query).map_err(QueryError)
    }

    /// Parses every `Cookie` header once and caches the resulting map in the extensions.
    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> {
        if let Some(cached) = self.extensions.get::<Arc<HashMap<String, Cookie<'static>>>>() {
            return Arc::clone(cached);
        }

        // Each header may hold several `name=value` pairs separated by `;`. Headers that are
        // not UTF-8 and pairs that fail to parse are ignored; collecting into a HashMap keeps
        // the last value seen for a repeated name.
        let parsed: HashMap<String, Cookie<'static>> = self
            .headers
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|value| std::str::from_utf8(value.as_bytes()).ok())
            .flat_map(|raw| raw.split(';'))
            .map(str::trim)
            .filter_map(|pair| Cookie::parse_encoded(pair).ok())
            .map(|cookie| (cookie.name().to_string(), cookie.into_owned()))
            .collect();

        let shared = Arc::new(parsed);
        self.extensions.insert(Arc::clone(&shared));
        shared
    }

    /// Compares the raw `Content-Type` header against `mime.to_string()`.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
        let content_type = self.headers.get(header::CONTENT_TYPE);
        match content_type {
            Some(value) if *value == mime.to_string() => Ok(()),
            _ => Err(WrongContentTypeError {
                received: content_type
                    .and_then(|value| value.to_str().ok())
                    .map(str::to_owned),
                expected: mime,
            }),
        }
    }

    /// Lazily inserts an empty `HeaderMap` extension and hands out a mutable reference to it.
    fn response_headers(&mut self) -> &mut HeaderMap {
        let missing = self.extensions.get::<HeaderMap>().is_none();
        if missing {
            self.extensions.insert(HeaderMap::new());
        }
        self.extensions
            .get_mut::<HeaderMap>()
            .expect("HeaderMap extension is present: it was inserted just above if missing")
    }
}
/// Token for double-submit CSRF protection.
///
/// The same value lives in the `csrf` cookie and must be echoed back in the `csrf` field of
/// forms decoded with `SputnikBody::into_form_csrf`.
#[derive(Clone)]
pub struct CsrfToken(String);

impl std::fmt::Display for CsrfToken {
    /// Writes the raw token value; no escaping is applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Name of the cookie that carries a serialized [`Flash`] message.
const FLASH_COOKIE_NAME: &str = "flash";

/// A short, categorised message stored in a cookie so that a later request can read it.
///
/// The category (`name`) and text (`message`) are serialized as `name:message` by the
/// `From<Flash> for Cookie` conversion and read back by [`Flash::from_request`].
pub struct Flash {
    name: String,
    message: String,
}

impl From<Flash> for Cookie<'_> {
    /// Serializes the flash into the `flash` cookie as `name:message`, expiring after five
    /// minutes.
    fn from(flash: Flash) -> Self {
        Cookie::build(FLASH_COOKIE_NAME, format!("{}:{}", flash.name, flash.message))
            // Pin the path: without one, the browser defaults it to the directory of the
            // request that set the cookie (RFC 6265 §5.1.4), so a flash set by e.g.
            // `POST /posts/create` would never be sent to a redirect target like `/`.
            .path("/")
            .max_age(Duration::minutes(5))
            .finish()
    }
}

impl Flash {
    /// Takes the flash message carried by the request's `flash` cookie, if any.
    ///
    /// Whenever the cookie is present, a deletion cookie (from `delete_cookie`, with path `/`)
    /// is queued on the response headers so the flash is consumed — this happens even when the
    /// value is malformed. Returns `None` if the cookie is absent or its value has no `:`.
    /// Only the first `:` separates name from message, so messages may contain colons.
    pub fn from_request(req: &mut Parts) -> Option<Self> {
        let cookies = req.cookies();
        let cookie = cookies.get(FLASH_COOKIE_NAME)?;

        // The deletion must use the same path the flash cookie is set with (`/`), otherwise
        // the browser treats it as a different cookie and keeps the original.
        let mut removal = delete_cookie(FLASH_COOKIE_NAME);
        removal.set_path("/");
        req.response_headers().set_cookie(removal);

        let mut fields = cookie.value().splitn(2, ':');
        match (fields.next(), fields.next()) {
            (Some(name), Some(message)) => Some(Flash {
                name: name.to_owned(),
                message: message.to_owned(),
            }),
            _ => None,
        }
    }

    /// Creates a flash with an arbitrary category `name`.
    pub fn new(name: String, message: String) -> Self {
        Flash { name, message }
    }

    /// Creates a flash in the `success` category.
    pub fn success(message: String) -> Self {
        Flash { name: "success".to_owned(), message }
    }

    /// Creates a flash in the `warning` category.
    pub fn warning(message: String) -> Self {
        Flash { name: "warning".to_owned(), message }
    }

    /// Creates a flash in the `error` category.
    pub fn error(message: String) -> Self {
        Flash { name: "error".to_owned(), message }
    }

    /// The flash category (the part before the first `:` in the cookie value).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The flash text (everything after the first `:` in the cookie value).
    pub fn message(&self) -> &str {
        &self.message
    }
}
impl CsrfToken {
    /// Returns the CSRF token for this request, creating one if necessary.
    ///
    /// Lookup order: a token already cached in the request extensions, then the `csrf`
    /// cookie. If neither exists, a random 16-character alphanumeric token is generated, a
    /// `Set-Cookie` for it (Secure, Path=/, one-hour max-age) is queued on the response
    /// headers, and the token is cached in the extensions so repeated calls agree.
    pub fn from_request(req: &mut Parts) -> Self {
        if let Some(token) = req.extensions.get::<CsrfToken>() {
            return token.clone();
        }
        if let Some(token) = csrf_token_from_cookies(req) {
            return token;
        }

        let token: String = rand::thread_rng()
            .sample_iter(Alphanumeric)
            .take(16)
            .collect();
        let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone());
        c.set_secure(Some(true));
        // Without an explicit path the browser scopes the cookie to the directory of the
        // current request (RFC 6265 §5.1.4); forms posting to another directory would then
        // arrive without the cookie and fail with `NoCookie`.
        c.set_path("/");
        c.set_max_age(Some(Duration::hours(1)));
        req.response_headers().set_cookie(c);

        let token = CsrfToken(token);
        req.extensions.insert(token.clone());
        token
    }

    /// Renders the token as a hidden `<input name=csrf>` form field.
    ///
    /// The value is HTML-attribute-escaped: a token read back from the request cookie is
    /// client-controlled, and an unescaped `"` would let it break out of the attribute.
    /// Generated (alphanumeric) tokens render unchanged.
    pub fn html_input(&self) -> String {
        format!(
            "<input name=csrf type=hidden value=\"{}\">",
            Self::escape_attribute(&self.0)
        )
    }

    /// Escapes the characters that are significant inside a quoted HTML attribute value.
    fn escape_attribute(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                other => escaped.push(other),
            }
        }
        escaped
    }
}
/// Helpers for consuming a request body.
#[async_trait]
pub trait SputnikBody {
    /// Collects the whole body into memory. The `hyper::Body` implementation applies no size
    /// limit.
    async fn into_bytes(self) -> Result<Bytes, BodyError>;

    /// Collects the body and decodes it as `application/x-www-form-urlencoded` into `T`.
    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError>;

    /// Like [`SputnikBody::into_form`], but first requires the form's `csrf` field to equal
    /// the token in the request's `csrf` cookie.
    async fn into_form_csrf<T: DeserializeOwned>(
        self,
        req: &mut Parts,
    ) -> Result<T, CsrfProtectedFormError>;

    /// Collects the body and decodes it as JSON into `T`.
    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>;
}
/// Reads the CSRF token from the `csrf` request cookie, if present.
///
/// When found, the token is also cached in the request extensions, where
/// `CsrfToken::from_request` looks first on subsequent calls.
fn csrf_token_from_cookies(req: &mut Parts) -> Option<CsrfToken> {
    let jar = req.cookies();
    let value = jar.get(CSRF_COOKIE_NAME)?.value().to_string();
    let token = CsrfToken(value);
    req.extensions.insert(token.clone());
    Some(token)
}
#[async_trait]
impl SputnikBody for hyper::Body {
async fn into_bytes(self) -> Result<Bytes, BodyError> {
hyper::body::to_bytes(self).await.map_err(BodyError)
}
async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError> {
let full_body = self.into_bytes().await?;
Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
}
async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError> {
let full_body = self.into_bytes().await?;
let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?;
match csrf_token_from_cookies(req) {
Some(token) => if token.to_string() == csrf_data.csrf {
Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
} else {
Err(CsrfProtectedFormError::Mismatch)
}
None => Err(CsrfProtectedFormError::NoCookie)
}
}
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError> {
let full_body = self.into_bytes().await?;
Ok(serde_json::from_slice::<T>(&full_body)?)
}
}
/// Deserialization target used by `into_form_csrf` to pull the `csrf` field out of a
/// urlencoded form body before the full form type is decoded. Other form fields are
/// ignored by this pass (no `deny_unknown_fields`).
#[derive(Deserialize)]
struct CsrfData {
/// Submitted token; compared against the value of the `csrf` cookie.
csrf: String,
}
#[derive(thiserror::Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);
#[derive(thiserror::Error, Debug)]
#[error("failed to read body")]
pub struct BodyError(pub hyper::Error);
#[derive(thiserror::Error, Debug)]
#[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))]
pub struct WrongContentTypeError {
pub expected: Mime,
pub received: Option<String>,
}
#[derive(thiserror::Error, Debug)]
pub enum FormError {
#[error("{0}")]
Body(#[from] BodyError),
#[error("form deserialize error: {0}")]
Deserialize(#[from] serde_urlencoded::de::Error),
}
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(thiserror::Error, Debug)]
pub enum JsonError {
#[error("{0}")]
Body(#[from] BodyError),
#[error("json deserialize error: {0}")]
Deserialize(#[from] serde_json::Error),
}
#[derive(thiserror::Error, Debug)]
pub enum CsrfProtectedFormError {
#[error("{0}")]
Body(#[from] BodyError),
#[error("form deserialize error: {0}")]
Deserialize(#[from] serde_urlencoded::de::Error),
#[error("no csrf token in form data")]
NoCsrf,
#[error("no csrf cookie")]
NoCookie,
#[error("csrf parameter doesn't match csrf cookie")]
Mismatch,
}
#[cfg(test)]
mod tests {
    use hyper::Request;
    use super::*;

    /// Builds request parts whose only header is `Cookie: <cookie>`.
    fn parts_with_cookie(cookie: &str) -> Parts {
        Request::builder()
            .header(header::COOKIE, cookie)
            .body(hyper::Body::empty())
            .unwrap()
            .into_parts()
            .0
    }

    #[test]
    fn test_csrf_token() {
        let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
        let tok1 = CsrfToken::from_request(&mut parts);
        let tok2 = CsrfToken::from_request(&mut parts);
        assert_eq!(tok1.to_string(), tok2.to_string());
        // Only the first call generates a token and queues its Set-Cookie.
        assert_eq!(parts.response_headers().len(), 1);
    }

    #[test]
    fn test_csrf_token_reuses_cookie() {
        let mut parts = parts_with_cookie("csrf=abc123");
        assert_eq!(CsrfToken::from_request(&mut parts).to_string(), "abc123");
        // An existing cookie is reused, so no new Set-Cookie header is queued.
        assert_eq!(parts.response_headers().len(), 0);
    }

    #[test]
    fn test_flash_from_request() {
        let mut parts = parts_with_cookie("other=1; flash=success:it:worked");
        let flash = Flash::from_request(&mut parts).expect("flash cookie should parse");
        assert_eq!(flash.name(), "success");
        // Only the first ':' separates name and message.
        assert_eq!(flash.message(), "it:worked");
        // Reading the flash queues exactly one Set-Cookie that deletes it.
        assert_eq!(parts.response_headers().len(), 1);
    }

    #[test]
    fn test_flash_malformed_is_still_cleared() {
        let mut parts = parts_with_cookie("flash=nocolon");
        assert!(Flash::from_request(&mut parts).is_none());
        assert_eq!(parts.response_headers().len(), 1);
    }

    #[test]
    fn test_flash_absent() {
        let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
        assert!(Flash::from_request(&mut parts).is_none());
        assert_eq!(parts.response_headers().len(), 0);
    }
}