sputnik 0.3.1

A lightweight layer on top of hyper to facilitate building web applications.
Documentation
//! Provides the [`SputnikParts`] and [`SputnikBody`] traits.

use cookie::Cookie;
use mime::Mime;
use serde::{Deserialize, de::DeserializeOwned};
use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
use rand::{Rng, distributions::Alphanumeric};
use async_trait::async_trait;

use crate::response::SputnikBuilder;

/// Name of the cookie holding the CSRF token that `csrf_html_input` issues
/// and `into_form_csrf` checks against the submitted form.
const CSRF_COOKIE_NAME: &str = "csrf";

/// Extension methods for the head ([`Parts`]) of an incoming request.
pub trait SputnikParts {
    /// Parses the query string of the request into a given struct.
    ///
    /// # Errors
    ///
    /// Returns a [`QueryError`] if the query string cannot be deserialized into `T`.
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError>;

    /// Parses the cookies of the request into a map keyed by cookie name.
    ///
    /// Takes `&mut self` so that an implementation can cache the parsed map
    /// with the request and hand out the same map on later calls.
    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;

    /// Enforces a specific Content-Type.
    ///
    /// # Errors
    ///
    /// Returns a [`WrongContentTypeError`] if the request's Content-Type header
    /// is missing or does not match `mime`.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;

    /// Retrieves the CSRF token from a cookie, or generates a new token and
    /// stores it as a cookie (via `builder`) if it doesn't exist.
    ///
    /// Returns a hidden HTML input to be embedded in forms that are received
    /// with [`crate::request::SputnikBody::into_form_csrf`].
    fn csrf_html_input(&mut self, builder: &mut Builder) -> String;
}

impl SputnikParts for hyper::http::request::Parts {
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError> {
        // A request without a query string is parsed like an empty query string.
        serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError)
    }

    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> {
        // The parsed map is cached in the request extensions, so the Cookie
        // headers are parsed at most once per request.
        if let Some(cached) = self.extensions.get::<Arc<HashMap<String, Cookie<'static>>>>() {
            return Arc::clone(cached);
        }

        let mut cookies = HashMap::new();
        for value in self.headers.get_all(header::COOKIE) {
            // Header values that are not valid UTF-8 are skipped.
            let raw = match std::str::from_utf8(value.as_bytes()) {
                Ok(raw) => raw,
                Err(_) => continue,
            };
            // Unparsable pairs are skipped; for repeated names the last one wins.
            for pair in raw.split(';').map(str::trim) {
                if let Ok(cookie) = Cookie::parse_encoded(pair) {
                    cookies.insert(cookie.name().to_string(), cookie.into_owned());
                }
            }
        }
        let cookies = Arc::new(cookies);
        self.extensions.insert(Arc::clone(&cookies));
        cookies
    }

    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
        // `None` when the header is missing or not valid UTF-8.
        let received = self
            .headers
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());

        // Compare parsed media types rather than raw strings: type/subtype are
        // case-insensitive, and clients routinely append parameters such as
        // `charset` (or `boundary` for multipart) that the caller did not ask for.
        // Only when `mime` itself carries parameters must the full media type,
        // parameters included, be equal.
        let matches = received
            .and_then(|value| value.parse::<Mime>().ok())
            .map_or(false, |received| {
                if mime.params().next().is_none() {
                    received.essence_str() == mime.essence_str()
                } else {
                    received == mime
                }
            });

        if matches {
            return Ok(());
        }
        Err(WrongContentTypeError {
            expected: mime,
            received: received.map(str::to_owned),
        })
    }

    fn csrf_html_input(&mut self, builder: &mut Builder) -> String {
        // Token issued earlier while handling this same request. Caching it lets
        // several forms rendered into one response share a single token instead
        // of emitting competing `Set-Cookie` headers that invalidate each other.
        #[derive(Clone)]
        struct IssuedCsrfToken(String);

        // Escapes `value` for use inside a double-quoted HTML attribute.
        fn escape_html_attribute(value: &str) -> String {
            let mut escaped = String::with_capacity(value.len());
            for ch in value.chars() {
                match ch {
                    '&' => escaped.push_str("&amp;"),
                    '<' => escaped.push_str("&lt;"),
                    '>' => escaped.push_str("&gt;"),
                    '"' => escaped.push_str("&quot;"),
                    '\'' => escaped.push_str("&#39;"),
                    other => escaped.push(other),
                }
            }
            escaped
        }

        let token = if let Some(token) = csrf_token_from_cookies(self) {
            token
        } else if let Some(IssuedCsrfToken(token)) =
            self.extensions.get::<IssuedCsrfToken>().cloned()
        {
            token
        } else {
            let token: String = rand::thread_rng().sample_iter(Alphanumeric).take(16).collect();
            let mut cookie = Cookie::new(CSRF_COOKIE_NAME, token.clone());
            cookie.set_secure(Some(true));
            cookie.set_max_age(Some(Duration::hours(1)));
            builder.set_cookie(cookie);
            self.extensions.insert(IssuedCsrfToken(token.clone()));
            token
        };

        // A token read back from the Cookie header is client-controlled (and
        // percent-decoded), so it must be escaped before landing in HTML.
        format!("<input name=csrf type=hidden value=\"{}\">", escape_html_attribute(&token))
    }
}

/// Extension methods for consuming the body of an incoming request.
#[async_trait]
pub trait SputnikBody {
    /// Reads the entire request body into memory.
    ///
    /// NOTE(review): the `hyper::Body` implementation buffers the whole body
    /// without any size limit; consider bounding body size upstream.
    ///
    /// # Errors
    ///
    /// Returns a [`BodyError`] if reading the body fails.
    async fn into_bytes(self) -> Result<Bytes, BodyError>;

    /// Parses an `application/x-www-form-urlencoded` request body into a given struct.
    ///
    /// This offers no CSRF protection, so you normally want to use
    /// [`SputnikBody::into_form_csrf()`] instead. The Content-Type header is
    /// not checked here.
    ///
    /// # Errors
    ///
    /// Returns [`FormError::Body`] if the body cannot be read and
    /// [`FormError::Deserialize`] if it cannot be deserialized into `T`.
    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError>;

    /// Parses an `application/x-www-form-urlencoded` request body into a given struct.
    /// Protects from CSRF by checking that the request body contains the same token
    /// retrieved from the cookies.
    ///
    /// The HTML form must embed a hidden input generated with [`crate::request::SputnikParts::csrf_html_input`].
    ///
    /// # Errors
    ///
    /// - [`CsrfProtectedFormError::Body`] if the body cannot be read.
    /// - [`CsrfProtectedFormError::NoCsrf`] if no `csrf` field can be read from the body.
    /// - [`CsrfProtectedFormError::NoCookie`] if the request carries no CSRF cookie.
    /// - [`CsrfProtectedFormError::Mismatch`] if the form token differs from the cookie token.
    /// - [`CsrfProtectedFormError::Deserialize`] if the body cannot be deserialized into `T`.
    async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError>;
}

/// Returns an owned copy of the value of the CSRF cookie sent with `req`,
/// or `None` when the request carries no such cookie.
fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> {
    let jar = req.cookies();
    let csrf_cookie = jar.get(CSRF_COOKIE_NAME)?;
    Some(csrf_cookie.value().to_owned())
}

#[async_trait]
impl SputnikBody for hyper::Body {
    async fn into_bytes(self) -> Result<Bytes, BodyError> {
        // Buffers the whole body in memory.
        hyper::body::to_bytes(self).await.map_err(BodyError)
    }

    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError> {
        let full_body = self.into_bytes().await?;
        Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
    }

    async fn into_form_csrf<T: DeserializeOwned>(
        self,
        req: &mut Parts,
    ) -> Result<T, CsrfProtectedFormError> {
        // Best-effort constant-time equality: the running time depends only on the
        // input lengths, so response timing does not reveal how many leading bytes
        // of a guessed token were correct (unlike `==`, which stops at the first
        // differing byte).
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        let full_body = self.into_bytes().await?;
        // Any failure to extract the `csrf` field is reported as NoCsrf.
        let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body)
            .map_err(|_| CsrfProtectedFormError::NoCsrf)?;
        let cookie_token =
            csrf_token_from_cookies(req).ok_or(CsrfProtectedFormError::NoCookie)?;

        if !constant_time_eq(cookie_token.as_bytes(), csrf_data.csrf.as_bytes()) {
            return Err(CsrfProtectedFormError::Mismatch);
        }
        // The body is only deserialized into `T` once the CSRF check has passed.
        Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
    }
}

/// Minimal view of a submitted form used for the CSRF check: only the `csrf`
/// field is extracted, and other fields are ignored (serde's default for
/// unknown fields).
#[derive(Deserialize)]
struct CsrfData {
    /// Token the form carried in its hidden `csrf` input.
    csrf: String,
}

#[derive(thiserror::Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);

#[derive(thiserror::Error, Debug)]
#[error("failed to read body")]
pub struct BodyError(pub hyper::Error);

#[derive(thiserror::Error, Debug)]
#[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))]
pub struct WrongContentTypeError {
    pub expected: Mime,
    pub received: Option<String>,
}

#[derive(thiserror::Error, Debug)]
pub enum FormError {
    #[error("{0}")]
    Body(#[from] BodyError),

    #[error("form deserialize error: {0}")]
    Deserialize(#[from] serde_urlencoded::de::Error),
}

#[derive(thiserror::Error, Debug)]
pub enum CsrfProtectedFormError {
    #[error("{0}")]
    Body(#[from] BodyError),

    #[error("form deserialize error: {0}")]
    Deserialize(#[from] serde_urlencoded::de::Error),

    #[error("no csrf token in form data")]
    NoCsrf,

    #[error("no csrf cookie")]
    NoCookie,

    #[error("csrf parameter doesn't match csrf cookie")]
    Mismatch,
}