tophat 0.3.0

A small, pragmatic, and flexible async http server
Documentation
// Cors module based on warp's.

//! Cors module
//!
//! Handles both simple and preflight CORS requests.
//!
//! Currently a super-simple, not-complete implementation.
//!
//! Does _not_ check for correctness of request headers and content-type.
// (Does anybody? I checked the warp and iron CORS middleware; I don't think they do.)
//!
//! Configure it with [`Cors::build`], which returns a [`CorsBuilder`]; the API is still minimal.
//!
//! ## Simple cors
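//! Only checks the client's `Origin` header. If that origin is allowed, the response gets an
//! `Access-Control-Allow-Origin` header echoing it (plus `Access-Control-Allow-Credentials` and
//! `Access-Control-Expose-Headers` when configured); otherwise the request is rejected as a bad
//! request.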
//!
//! ## Preflight cors
//! - client method: `OPTIONS`
//! - client header: origin
//! - client header: access-control-request-method
//! - client header: access-control-request-headers
//!
//! - server status: 200 OK
//! - server header: access-control-allow-origin
//! - server header: access-control-allow-methods
//! - server header: access-control-allow-headers
//! - server header: access-control-max-age (86400s is one day)

use futures_util::io::{AsyncRead, AsyncWrite};
use headers::{
    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt,
    Origin,
};
use http::{
    header::{self, HeaderMap, HeaderName, HeaderValue},
    Method, StatusCode,
};
use std::collections::HashSet;
use std::convert::TryFrom;

use crate::{
    server::{
        glitch::{Glitch, Result},
        ResponseWriter,
    },
    Request,
};

/// Builder for a [`Cors`] configuration.
///
/// Obtain one from [`Cors::build`], chain the configuration methods, then call
/// [`CorsBuilder::finish`] to produce the [`Cors`].
#[derive(Clone, Debug)]
pub struct CorsBuilder {
    /// Preflight and simple: whether to add `Access-Control-Allow-Credentials: true`.
    /// Defaults to `false`.
    pub credentials: bool,
    /// Preflight only: request headers the client is allowed to send.
    pub allowed_headers: HashSet<HeaderName>,
    /// Preflight and simple: response headers the client is allowed to read.
    pub exposed_headers: HashSet<HeaderName>,
    /// Preflight only: value for `Access-Control-Max-Age`, in seconds.
    pub max_age: Option<u64>,
    /// Preflight only: request methods the client is allowed to use.
    pub methods: HashSet<http::Method>,
    /// Preflight and simple: allowed origins. `None` (the default) allows any origin.
    pub origins: Option<HashSet<HeaderValue>>,
}

impl CorsBuilder {
    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
    pub fn allow_credentials(mut self, allow: bool) -> Self {
        self.credentials = allow;
        self
    }

    /// Adds a method to the existing list of allowed request methods.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `http::Method`.
    pub fn allow_method<M>(mut self, method: M) -> Self
    where
        http::Method: TryFrom<M>,
    {
        let method = match TryFrom::try_from(method) {
            Ok(m) => m,
            _ => panic!("illegal Method"),
        };
        self.methods.insert(method);
        self
    }

    /// Adds multiple methods to the existing list of allowed request methods.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `http::Method`.
    pub fn allow_methods<I>(mut self, methods: I) -> Self
    where
        I: IntoIterator,
        http::Method: TryFrom<I::Item>,
    {
        let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
            Ok(m) => m,
            _ => panic!("illegal Method"),
        });
        self.methods.extend(iter);
        self
    }

    /// Adds a header to the list of allowed request headers.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `http::header::HeaderName`.
    pub fn allow_header<H>(mut self, header: H) -> Self
    where
        HeaderName: TryFrom<H>,
    {
        let header = match TryFrom::try_from(header) {
            Ok(m) => m,
            _ => panic!("illegal Header"),
        };
        self.allowed_headers.insert(header);
        self
    }

    /// Adds multiple headers to the list of allowed request headers.
    ///
    /// # Panics
    ///
    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
    pub fn allow_headers<I>(mut self, headers: I) -> Self
    where
        I: IntoIterator,
        HeaderName: TryFrom<I::Item>,
    {
        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
            Ok(h) => h,
            _ => panic!("illegal Header"),
        });
        self.allowed_headers.extend(iter);
        self
    }

    /// Adds a header to the list of exposed headers.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `http::header::HeaderName`.
    pub fn expose_header<H>(mut self, header: H) -> Self
    where
        HeaderName: TryFrom<H>,
    {
        let header = match TryFrom::try_from(header) {
            Ok(m) => m,
            _ => panic!("illegal Header"),
        };
        self.exposed_headers.insert(header);
        self
    }

    /// Adds multiple headers to the list of exposed headers.
    ///
    /// # Panics
    ///
    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
    pub fn expose_headers<I>(mut self, headers: I) -> Self
    where
        I: IntoIterator,
        HeaderName: TryFrom<I::Item>,
    {
        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
            Ok(h) => h,
            _ => panic!("illegal Header"),
        });
        self.exposed_headers.extend(iter);
        self
    }

    /// Sets that *any* `Origin` header is allowed.
    ///
    /// # Warning
    ///
    /// This can allow websites you didn't instead to access this resource,
    /// it is usually better to set an explicit list.
    pub fn allow_any_origin(mut self) -> Self {
        self.origins = None;
        self
    }

    /// Add an origin to the existing list of allowed `Origin`s.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `Origin`.
    pub fn allow_origin(self, origin: impl IntoOrigin) -> Self {
        self.allow_origins(Some(origin))
    }

    /// Add multiple origins to the existing list of allowed `Origin`s.
    ///
    /// # Panics
    ///
    /// Panics if the provided argument is not a valid `Origin`.
    pub fn allow_origins<I>(mut self, origins: I) -> Self
    where
        I: IntoIterator,
        I::Item: IntoOrigin,
    {
        let iter = origins
            .into_iter()
            .map(IntoOrigin::into_origin)
            .map(|origin| {
                origin
                    .to_string()
                    .parse()
                    .expect("Origin is always a valid HeaderValue")
            });

        self.origins.get_or_insert_with(HashSet::new).extend(iter);

        self
    }

    /// Sets the `Access-Control-Max-Age` header.
    pub fn max_age(mut self, seconds: u64) -> Self {
        self.max_age = Some(seconds);
        self
    }

    /// Finish building a Cors
    pub fn finish(self) -> Cors {
        let exposed_headers = if self.exposed_headers.is_empty() {
            None
        } else {
            Some(self.exposed_headers.into_iter().collect())
        };

        Cors {
            credentials: self.credentials,
            allowed_headers: self.allowed_headers.iter().cloned().collect(),
            allowed_headers_set: self.allowed_headers,
            exposed_headers,
            max_age: self.max_age,
            methods: self.methods.iter().cloned().collect(),
            methods_set: self.methods,
            origins: self.origins,
        }
    }
}

/// A finished CORS configuration, produced by [`CorsBuilder::finish`].
///
/// See module docs for more details.
#[derive(Clone, Debug)]
pub struct Cors {
    /// Preflight and simple: whether to add `Access-Control-Allow-Credentials: true`.
    /// Defaults to `false`.
    credentials: bool,

    /// Allowed request headers as a set, used to check `Access-Control-Request-Headers`.
    allowed_headers_set: HashSet<HeaderName>,
    /// Preflight only: allowed request headers, as the typed response header.
    allowed_headers: AccessControlAllowHeaders,

    /// Preflight and simple: response headers the client may read; `None` when none are exposed.
    exposed_headers: Option<AccessControlExposeHeaders>,

    /// Preflight only: `Access-Control-Max-Age` in seconds.
    max_age: Option<u64>,

    /// Allowed methods as a set, used to check `Access-Control-Request-Method`.
    methods_set: HashSet<http::Method>,
    /// Preflight only: allowed methods, as the typed response header.
    methods: AccessControlAllowMethods,

    /// Preflight and simple: allowed origins; `None` allows any origin.
    /// When responding, the origin sent by the client is echoed back if it is allowed.
    origins: Option<HashSet<HeaderValue>>,
}

impl Cors {
    /// Build a Cors
    pub fn build() -> CorsBuilder {
        CorsBuilder {
            credentials: false,
            allowed_headers: HashSet::new(),
            exposed_headers: HashSet::new(),
            max_age: None,
            methods: HashSet::new(),
            origins: None,
        }
    }

    // `Options` method differentiates preflight from simple. Does not check for correctness of a
    // simple request.
    //
    // The design seems a little weird in terms of error handling; basically
    // - Ok means continuing to endpoint. This is for both simple cors and not cors
    // - Err means short-circuit. This is for preflight and invalid
    /// Validate Cors.
    ///
    /// - handles simple Cors
    /// - handles preflight.
    ///
    /// See example `cors` to set up properly as middleware.
    pub fn validate<W>(&self, req: &Request, resp_wtr: &mut ResponseWriter<W>) -> Result<()>
    where
        W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
    {
        let req_method = req.method();
        let req_origin = req.headers().get(header::ORIGIN);

        match (req_method, req_origin) {
            (&Method::OPTIONS, Some(origin)) => {
                // Preflight checks
                if !self.is_origin_allowed(origin) {
                    return Err(Glitch::bad_request());
                    // TODO error message?
                    //Err(Forbidden::OriginNotAllowed);
                }

                let headers = req.headers();

                if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
                    if !self.is_method_allowed(req_method) {
                        return Err(Glitch::bad_request());
                        // TODO error message?
                        //Err(Forbidden::MethodNotAllowed);
                    }
                } else {
                    println!("hit");
                    return Err(Glitch::bad_request());
                    // TODO error message?
                    // return Err(Forbidden::MethodNotAllowed);
                }

                if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
                    // TODO error message?
                    //let headers = req.headers()
                    //    .to_str()
                    //    .map_err(|_| Forbidden::HeaderNotAllowed)?;
                    let headers = match req_headers.to_str() {
                        Ok(h) => h,
                        Err(_) => return Err(Glitch::bad_request()),
                    };
                    for header in headers.split(',') {
                        if !self.is_header_allowed(header) {
                            return Err(Glitch::bad_request());
                            // TODO error message?
                            //return Err(Forbidden::HeaderNotAllowed);
                        }
                    }
                }

                // If all checks successful, continue with headers for resp.
                //
                // NOTE it looks kind of weird, but a Glitch is used to have an early return for
                // preflight.
                //
                // set headers
                let mut resp = Glitch::new();
                let mut headers = HeaderMap::new();
                self.append_preflight_headers(&mut headers);
                // set allowed-origin header
                headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());

                resp.status = Some(StatusCode::OK);
                resp.headers = Some(headers);

                Err(resp)
            }
            (_, Some(origin)) => {
                // Simple
                if self.is_origin_allowed(origin) {
                    // set common headers
                    let mut headers = resp_wtr.response_mut().headers_mut();
                    self.append_common_headers(&mut headers);
                    // set allowed-origin header
                    resp_wtr.insert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());

                    return Ok(());
                }

                // If origin is not allowed
                Err(Glitch::bad_request())
            }
            (_, _) => {
                // All other requests are not Cors
                Ok(())
            }
        }
    }

    fn is_method_allowed(&self, header: &HeaderValue) -> bool {
        http::Method::from_bytes(header.as_bytes())
            .map(|method| self.methods_set.contains(&method))
            .unwrap_or(false)
    }

    fn is_header_allowed(&self, header: &str) -> bool {
        HeaderName::from_bytes(header.as_bytes())
            .map(|header| self.allowed_headers_set.contains(&header))
            .unwrap_or(false)
    }

    fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
        if let Some(ref allowed) = self.origins {
            allowed.contains(origin)
        } else {
            true
        }
    }

    fn append_preflight_headers(&self, headers: &mut HeaderMap) {
        self.append_common_headers(headers);

        headers.typed_insert(self.allowed_headers.clone());
        headers.typed_insert(self.methods.clone());

        if let Some(max_age) = self.max_age {
            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
        }
    }

    fn append_common_headers(&self, headers: &mut HeaderMap) {
        if self.credentials {
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            );
        }
        if let Some(expose_headers_header) = &self.exposed_headers {
            headers.typed_insert(expose_headers_header.clone())
        }
    }
}

/// Conversion into a `headers::Origin`, used by [`CorsBuilder::allow_origin`] and
/// [`CorsBuilder::allow_origins`].
pub trait IntoOrigin {
    /// Converts `self` into an `Origin` for CORS checks.
    fn into_origin(self) -> Origin;
}

impl<'a> IntoOrigin for &'a str {
    /// Parses a `scheme://host` string such as `https://example.com`.
    ///
    /// # Panics
    ///
    /// Panics with `missing scheme` if the string contains no `://` separator, and with
    /// `invalid Origin` if the scheme and remainder do not form a valid `Origin`.
    fn into_origin(self) -> Origin {
        let mut pieces = self.splitn(2, "://");
        // `splitn` always yields a first piece, so only a missing second piece can fail here.
        let (scheme, remainder) = match (pieces.next(), pieces.next()) {
            (Some(scheme), Some(remainder)) => (scheme, remainder),
            _ => panic!("missing scheme"),
        };

        Origin::try_from_parts(scheme, remainder, None).expect("invalid Origin")
    }
}