sputnik 0.1.1

A lightweight layer on top of hyper to facilitate building web applications.
Documentation
//! A lightweight layer on top of [Hyper](https://hyper.rs/)
//! to facilitate building web applications.
use std::collections::HashMap;

pub use error::Error;
pub use mime;
use cookie::Cookie;
use header::HeaderName;
use mime::{APPLICATION_WWW_FORM_URLENCODED, Mime};
use serde::{Deserialize, de::DeserializeOwned};
use hyper::{Body, StatusCode, body::Bytes, header::{self, HeaderValue}, http::request::Parts};
use time::{Duration, OffsetDateTime};

pub use httpdate;

/// The raw hyper request type that `Request` is built from.
type HyperRequest = hyper::Request<Body>;
/// The raw hyper response type that `Response` wraps and converts back into.
type HyperResponse = hyper::Response<Body>;

pub mod security;
mod error;
mod signed;

/// Convenience wrapper around [`hyper::Request`].
///
/// Keeps the request head and the body stream apart, together with a cookie
/// cache that stays empty until the cookies are first asked for.
pub struct Request {
    /// Body stream of the request; drained by the body-reading methods.
    body: Body,
    /// Request head: method, URI, headers, and so on.
    parts: Parts,
    /// Cookies parsed from the `Cookie` headers; `None` until first requested.
    cookies: Option<HashMap<String, Cookie<'static>>>,
}

impl From<HyperRequest> for Request {
    fn from(req: HyperRequest) -> Self {
        let (parts, body) = req.into_parts();
        Request{parts, body, cookies: None}
    }
}

/// Unwraps a [`Response`] into the underlying [`hyper::Response`].
///
/// Implemented as `From` rather than `Into`. The standard library's blanket impl
/// still provides `Into<hyper::Response<Body>>` for `Response`, so existing
/// `.into()` calls keep working, and `hyper::Response::from(response)` works too.
impl From<Response> for HyperResponse {
    fn from(response: Response) -> Self {
        response.res
    }
}

fn enforce_content_type(req: &Parts, mime: Mime) -> Result<(),Error> {
    let received_type = req.headers.get(header::CONTENT_TYPE).ok_or(Error::bad_request(format!("expected content-type: {}", mime)))?;
    if *received_type != mime.to_string() {
        return Err(Error::bad_request(format!("expected content-type: {}", mime)))
    }
    Ok(())
}

/// Minimal view of a urlencoded form body, used only to extract the CSRF token.
///
/// No `deny_unknown_fields` is set, so serde ignores the form's other fields and
/// only requires a `csrf` field to be present.
#[derive(Deserialize)]
struct CsrfData {
    /// Value of the `csrf` form field.
    csrf: String,
}

impl Request {
    /// Returns the cookies sent with the request, keyed by cookie name.
    ///
    /// All `Cookie` headers are parsed on the first call and the result is cached,
    /// so later calls do no parsing.
    ///
    /// Parsing is lenient:
    /// - Header values that are not valid UTF-8 are skipped.
    /// - Cookie pairs that fail to parse are skipped.
    /// - If a name occurs more than once, the last occurrence wins.
    pub fn cookies(&mut self) -> &HashMap<String,Cookie> {
        // Borrow the headers on their own so the cache field can be borrowed mutably.
        let headers = &self.parts.headers;
        self.cookies.get_or_insert_with(|| {
            headers
                .get_all(header::COOKIE)
                .iter()
                .filter_map(|value| std::str::from_utf8(value.as_bytes()).ok())
                .flat_map(|raw| raw.split(';'))
                .filter_map(|pair| Cookie::parse_encoded(pair.trim()).ok())
                .map(|cookie| (cookie.name().to_string(), cookie.into_owned()))
                .collect()
        })
    }

    /// Returns the HTTP method of the request.
    pub fn method(&self) -> &hyper::Method {
        &self.parts.method
    }

    /// Returns the request URI.
    pub fn uri(&self) -> &hyper::Uri {
        &self.parts.uri
    }

    /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
    ///
    /// This does make you vulnerable to CSRF so you normally want to use
    /// [`into_form_csrf()`](Request::into_form_csrf) instead.
    ///
    /// The body stream is drained by this call (see [`Request::into_body`]).
    ///
    /// # Errors
    ///
    /// Returns a bad-request error if the `Content-Type` check fails or if the body
    /// cannot be deserialized into `T`. Returns an internal error if the body cannot
    /// be read.
    ///
    /// # Example
    ///
    /// ```
    /// use hyper::{Response, Body};
    /// use sputnik::{Request, Error};
    /// use serde::Deserialize;
    ///
    /// #[derive(Deserialize)]
    /// struct Message {text: String, year: i64}
    ///
    /// async fn greet(req: &mut Request) -> Result<Response<Body>, Error> {
    ///     let msg: Message = req.into_form().await?;
    ///     Ok(Response::new(format!("hello {}", msg.text).into()))
    /// }
    /// ```
    pub async fn into_form<T: DeserializeOwned>(&mut self) -> Result<T,Error> {
        enforce_content_type(&self.parts, APPLICATION_WWW_FORM_URLENCODED)?;
        let full_body = self.into_body().await?;
        serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string()))
    }

    /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
    /// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies.
    ///
    /// The CSRF parameter is expected as the `csrf` parameter in the request body.
    /// This means for HTML forms you need to embed the token as a hidden input.
    ///
    /// `T` is deserialized from the same body, which still contains the `csrf` field.
    /// `T` must therefore tolerate that extra field (e.g. it must not use
    /// `#[serde(deny_unknown_fields)]`).
    ///
    /// # Errors
    ///
    /// - A bad-request error if the `Content-Type` check fails.
    /// - A bad-request error (`"no csrf token"`) if no `csrf` field can be parsed
    ///   from the body.
    /// - The error returned by `csrf_token.matches` if the token does not match.
    /// - A bad-request error if the body cannot be deserialized into `T`.
    /// - An internal error if the body cannot be read.
    ///
    /// # Example
    ///
    /// ```
    /// use hyper::{Method};
    /// use sputnik::{Request, Response, Error};
    /// use sputnik::security::CsrfToken;
    /// use serde::Deserialize;
    ///
    /// #[derive(Deserialize)]
    /// struct Message {text: String}
    ///
    /// async fn greet(req: &mut Request) -> Result<Response, Error> {
    ///     let mut response = Response::new();
    ///     let csrf_token = CsrfToken::from_request(req, &mut response);
    ///     *response.body() = match req.method() {
    ///         &Method::GET => format!("<form method=post>
    ///             <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into(),
    ///         &Method::POST => {
    ///             let msg: Message = req.into_form_csrf(&csrf_token).await?;
    ///             format!("hello {}", msg.text).into()
    ///         },
    ///         _ => return Err(Error::method_not_allowed("only GET and POST allowed".to_owned())),
    ///     };
    ///     Ok(response)
    /// }
    /// ```
    pub async fn into_form_csrf<T: DeserializeOwned>(&mut self, csrf_token: &security::CsrfToken) -> Result<T,Error> {
        enforce_content_type(&self.parts, APPLICATION_WWW_FORM_URLENCODED)?;
        let full_body = self.into_body().await?;
        // First pass: pull out only the token and verify it before trusting the form.
        let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_|Error::bad_request("no csrf token".to_string()))?;
        csrf_token.matches(csrf_data.csrf)?;
        // Second pass: deserialize the same bytes into the caller's type.
        serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string()))
    }

    /// Reads the remaining request body into one contiguous buffer.
    ///
    /// The body is a stream that this call drains, so bytes returned here are not
    /// seen by later reads.
    ///
    /// NOTE(review): no size limit is applied. Consider capping the body size when
    /// serving untrusted clients.
    ///
    /// # Errors
    ///
    /// Returns an internal error (`"failed to read body"`) if reading the stream fails.
    pub async fn into_body(&mut self) -> Result<Bytes,Error> {
        hyper::body::to_bytes(&mut self.body).await.map_err(|_|Error::internal("failed to read body".to_string()))
    }

    /// Parses the query string of the request into a given struct.
    ///
    /// A request without a query string is treated as having an empty one.
    ///
    /// # Errors
    ///
    /// Returns a bad-request error if the query string cannot be deserialized into `T`.
    pub fn query<T: DeserializeOwned>(&self) -> Result<T,Error> {
        serde_urlencoded::from_str::<T>(self.parts.uri.query().unwrap_or("")).map_err(|e|Error::bad_request(e.to_string()))
    }
}

/// Convenience wrapper around [`hyper::Response`].
///
/// Create one with [`Response::new`] and adjust it through its accessor and
/// setter methods. Convert it into a [`hyper::Response`] with `.into()` when
/// handing it back to hyper.
pub struct Response {
    /// The wrapped hyper response that every accessor operates on.
    res: HyperResponse,
}

impl Response {
    /// Creates a blank response with an empty body.
    ///
    /// Status and headers are the defaults of `hyper::Response::new`
    /// (`200 OK`, no headers).
    pub fn new() -> Self {
        Response {
            res: HyperResponse::new(Body::empty()),
        }
    }

    /// Returns a mutable reference to the status code.
    pub fn status(&mut self) -> &mut StatusCode {
        self.res.status_mut()
    }

    /// Returns a mutable reference to the body.
    pub fn body(&mut self) -> &mut Body {
        self.res.body_mut()
    }

    /// Returns a mutable reference to the header map.
    pub fn headers(&mut self) -> &mut hyper::HeaderMap<header::HeaderValue> {
        self.res.headers_mut()
    }

    /// Sets `header` to `value`, removing any values previously set for that header.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a valid header value. This happens when it contains
    /// control characters such as `\r` or `\n`, or non-ASCII characters. Validate or
    /// encode untrusted input before passing it here.
    pub fn set_header<S: AsRef<str>>(&mut self, header: HeaderName, value: S) {
        self.res.headers_mut().insert(header, HeaderValue::from_str(value.as_ref()).unwrap());
    }

    /// Sets the `Content-Type` header to `mime`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// Panics if the rendered media type is not a valid header value.
    pub fn set_content_type(&mut self, mime: mime::Mime) {
        self.set_header(header::CONTENT_TYPE, mime.to_string());
    }

    /// Turns the response into a redirect.
    ///
    /// Sets the status to `code` and the `Location` header to `location`. The body is
    /// left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `location` is not a valid header value, e.g. it contains a newline or
    /// non-ASCII characters. Percent-encode such URLs first.
    pub fn redirect<S: AsRef<str>>(&mut self, location: S, code: StatusCode) {
        *self.res.status_mut() = code;
        self.set_header(header::LOCATION, location);
    }

    /// Appends a `Set-Cookie` header for `cookie`.
    ///
    /// Earlier `Set-Cookie` headers are kept, so this can be called once per cookie.
    /// The cookie is serialized with `Cookie::encoded`.
    ///
    /// # Panics
    ///
    /// Panics if the serialized cookie is not a valid header value.
    pub fn set_cookie(&mut self, cookie: Cookie) {
        self.res.headers_mut().append(header::SET_COOKIE, cookie.encoded().to_string().parse().unwrap());
    }

    /// Asks the client to delete the cookie `name`.
    ///
    /// Sends an empty cookie with `Max-Age=0` and an `Expires` date one year in
    /// the past.
    ///
    /// NOTE(review): the removal cookie carries no `Path` or `Domain` attribute.
    /// Browsers only replace a stored cookie whose name, path and domain all match,
    /// so a cookie originally set with an explicit `Path`/`Domain` may survive this call.
    pub fn delete_cookie(&mut self, name: &str) {
        let mut cookie = Cookie::new(name, "");
        cookie.set_max_age(Duration::seconds(0));
        cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365));
        self.set_cookie(cookie);
    }
}

impl Default for Response {
    /// Equivalent to [`Response::new`].
    fn default() -> Self {
        Self::new()
    }
}