kvarn 0.5.0

A forward-thinking fast web server designed to fit your needs, efficiently.
Documentation
//! [CORS](https://en.wikipedia.org/wiki/Cross-origin_resource_sharing) implementation for Kvarn.
//!
//! See [`Cors`] for usage details.

use crate::prelude::*;
use extensions::RuleSet;

/// A [CORS](https://en.wikipedia.org/wiki/Cross-origin_resource_sharing) ruleset for Kvarn.
///
/// This is a [`RuleSet`] whose rules are allow lists, each added for a path
/// (see the example below).
///
/// Use [`Extensions::with_cors`] to allow selected CORS requests.
///
/// By default, Kvarn uses an empty [`RuleSet`]; all CORS requests are rejected.
///
/// # Examples
///
/// ```
/// # use kvarn::prelude::*;
/// // Allow `https://icelk.dev` and `https://kvarn.org` to access all images.
/// // Also allow all requests from `http://example.org` access to the API.
/// let cors =
///     Cors::empty()
///         .add(
///             "/images/*",
///             CorsAllowList::new(Duration::from_secs(60*60*24*365))
///                 .add_origin("https://icelk.dev")
///                 .add_origin("https://kvarn.org")
///             )
///         .add(
///             "/api/*",
///             CorsAllowList::default()
///                 .add_origin("http://example.org")
///                 .add_method(Method::PUT)
///                 .add_method(Method::POST)
///         );
/// ```
pub type Cors = RuleSet<AllowList>;
impl Cors {
    /// Looks up the rule registered for `uri_path` and asks it whether the
    /// (cross-origin) request's `origin` [`Uri`] is allowed.
    ///
    /// Returns [`None`] if no rule is found for `uri_path` or if the rule rejects `origin`.
    ///
    /// See [`CorsAllowList::check`] for info about the return types.
    pub fn check_origin(
        &self,
        origin: &Uri,
        uri_path: &str,
    ) -> Option<(MethodAllowList, &[HeaderName], Duration)> {
        let allow_list = self.get(uri_path)?;
        allow_list.check(origin)
    }
    /// Decides whether `request` (its [`Request::headers`], [`Request::uri`], and method)
    /// is allowed by this ruleset.
    ///
    /// > This will not check for errors in `access-control-request-headers`.
    ///
    /// Prefer this over [`Self::check_origin`]: requests without an `origin` header and
    /// requests whose `origin` is the same origin as the request URI are always accepted,
    /// with every method allowed, no extra headers, and a cache duration of one week.
    ///
    /// For all other requests, the `origin` has to be accepted by [`Self::check_origin`]
    /// and the request's method has to be in the returned method allow list.
    ///
    /// See [`CorsAllowList::check`] for info about the return types.
    pub fn check_cors_request<T>(
        &self,
        request: &Request<T>,
    ) -> Option<(MethodAllowList, &[HeaderName], Duration)> {
        // What same-origin (and origin-less) requests are granted.
        let unrestricted = (
            MethodAllowList::All,
            &[][..],
            Duration::from_secs(60 * 60 * 24 * 7),
        );

        let header = match request.headers().get("origin") {
            Some(header) => header,
            None => return Some(unrestricted),
        };

        // A header which isn't valid text can never be the same origin.
        let same_origin = header
            .to_str()
            .map_or(false, |value| Self::is_part_of_origin(value, request.uri()));
        if same_origin {
            return Some(unrestricted);
        }

        // Cross-origin: the header has to be a valid URI, allowed by the rule for the path,
        // and the method used has to be allowed for that origin.
        let origin = Uri::try_from(header.as_bytes()).ok()?;
        self.check_origin(&origin, request.uri().path())
            .filter(|(methods, _, _)| methods.allowed(request.method()))
    }
    /// Checks if `uri` is the same origin as `origin`.
    ///
    /// An `origin` containing `://` is split into scheme and authority; both have to be
    /// equal to those of `uri`. An `origin` without `://` is only considered the same
    /// origin if it is exactly `localhost` or `null`.
    ///
    /// NOTE(review): browsers send `Origin: null` for opaque origins (e.g. sandboxed
    /// documents); confirm that treating it as same-origin is intended.
    fn is_part_of_origin(origin: &str, uri: &Uri) -> bool {
        match origin.split_once("://") {
            None => origin == "localhost" || origin == "null",
            Some((scheme, authority)) => {
                uri.scheme_str() == Some(scheme)
                    && uri
                        .authority()
                        .map_or(false, |uri_authority| uri_authority.as_str() == authority)
            }
        }
    }
}

/// A CORS allow list which allows hosts, methods, and headers from an associated path.
/// This is a builder-like struct.
/// Use the `add_*` methods to add allowed origins, methods, and headers.
/// Multiple allow lists can be added to a [`Cors`] instance.
/// See the example at [`Cors`].
///
/// Use [`RuleSet::add`] to add a rule.
#[must_use]
#[derive(Debug)]
pub struct AllowList {
    /// Origins explicitly allowed through [`Self::add_origin`] and [`Self::add_origin_uri`].
    allowed: Vec<Uri>,
    /// When `true`, every origin is accepted, regardless of the contents of `allowed`.
    allow_all_origins: bool,
    /// The allowed request methods. [`None`] means all methods are allowed.
    methods: Option<Vec<Method>>,
    /// The request headers allowed in CORS requests.
    headers: Vec<HeaderName>,
    /// The cache duration handed back by [`Self::check`];
    /// the preflight handler sends it as `access-control-max-age`.
    cache_for: Duration,
}
impl AllowList {
    /// Creates an empty CORS allow list with the client cache duration of `cache_for`.
    ///
    /// No origins and no headers are allowed.
    /// The methods `GET`, `HEAD`, and `OPTIONS` are allowed from the start.
    pub fn new(cache_for: Duration) -> Self {
        Self {
            allowed: Vec::new(),
            allow_all_origins: false,
            methods: Some(vec![Method::GET, Method::HEAD, Method::OPTIONS]),
            headers: Vec::new(),
            cache_for,
        }
    }
    /// Allows CORS request from `allowed_origin`.
    /// Note that the scheme (`https` / `http`) is sensitive.
    /// Use [`Self::add_origin_uri`] for a [`Uri`] input.
    ///
    /// # Panics
    ///
    /// Panics if `allowed_origin` is not a valid [`Uri`]
    /// or if it doesn't contain a host AND a scheme.
    pub fn add_origin(self, allowed_origin: impl AsRef<str>) -> Self {
        let uri = Uri::try_from(allowed_origin.as_ref())
            .expect("an allowed CORS origin has to be a valid URI");
        self.add_origin_uri(uri)
    }
    /// Allows CORS request from `allowed_origin`.
    /// Note that the scheme (`https` / `http`) is sensitive.
    ///
    /// # Panics
    ///
    /// Panics if `allowed_origin` doesn't contain a host AND a scheme.
    pub fn add_origin_uri(mut self, allowed_origin: Uri) -> Self {
        assert!(
            allowed_origin.host().is_some(),
            "an allowed CORS origin has to contain a host"
        );
        assert!(
            allowed_origin.scheme().is_some(),
            "an allowed CORS origin has to contain a scheme"
        );
        self.allowed.push(allowed_origin);
        self
    }
    /// Enables the flag to allow all origins to use the set methods and headers in CORS requests.
    pub fn allow_all_origins(mut self) -> Self {
        self.allow_all_origins = true;
        self
    }
    /// Allows the listed origin(s) (added via [`Self::add_origin`])
    /// to request using `allowed_method`.
    ///
    /// A method already in the list isn't added again.
    ///
    /// Note that if [`Self::allow_all_methods`] was called earlier, this starts a new list
    /// containing only `allowed_method`; all methods are then no longer allowed.
    pub fn add_method(mut self, allowed_method: Method) -> Self {
        let methods = self.methods.get_or_insert_with(Vec::new);
        if !methods.contains(&allowed_method) {
            methods.push(allowed_method);
        }
        self
    }
    /// Allows all methods.
    ///
    /// This discards the list of individually allowed methods.
    pub fn allow_all_methods(mut self) -> Self {
        self.methods = None;
        self
    }
    /// Allows the listed origin(s) (added via [`Self::add_origin`])
    /// to send the `allowed_header` in the request.
    ///
    /// A header already in the list isn't added again.
    pub fn add_header(mut self, allowed_header: HeaderName) -> Self {
        if !self.headers.contains(&allowed_header) {
            self.headers.push(allowed_header);
        }
        self
    }
    /// Borrows the allowed methods as a [`MethodAllowList`].
    fn get_methods(&self) -> MethodAllowList {
        match self.methods.as_deref() {
            Some(methods) => MethodAllowList::Selected(methods),
            None => MethodAllowList::All,
        }
    }
    /// The port implied by `scheme` when a URI doesn't state one.
    fn default_port(scheme: &str) -> Option<u16> {
        if scheme.eq_ignore_ascii_case("http") {
            Some(80)
        } else if scheme.eq_ignore_ascii_case("https") {
            Some(443)
        } else {
            None
        }
    }
    /// Checks if `origin` is the same origin (scheme, host, and port) as the `allowed` entry.
    ///
    /// Scheme and host are compared ASCII case-insensitively. A missing port is treated as
    /// the default port of the scheme, so `https://kvarn.org:443` and `https://kvarn.org`
    /// are the same origin. Browsers serialize the `origin` header in lower case and
    /// without default ports, so entries written differently would otherwise never match.
    fn origin_matches(allowed: &Uri, origin: &Uri) -> bool {
        // The scheme is asserted to exist when the entry is added;
        // the fallback is kept for parity with previous behaviour.
        let scheme = allowed.scheme_str().unwrap_or("https");
        let same_scheme = origin
            .scheme_str()
            .map_or(false, |origin_scheme| origin_scheme.eq_ignore_ascii_case(scheme));
        if !same_scheme {
            return false;
        }

        let same_host = match (allowed.host(), origin.host()) {
            (Some(allowed_host), Some(origin_host)) => {
                allowed_host.eq_ignore_ascii_case(origin_host)
            }
            _ => false,
        };
        if !same_host {
            return false;
        }

        // The schemes are equal, so both URIs share the same default port.
        let default_port = Self::default_port(scheme);
        allowed.port_u16().or(default_port) == origin.port_u16().or(default_port)
    }
    /// Checks if the `origin` is allowed according to the allow list.
    ///
    /// Returns [`Some`] if `origin` is allowed, with the [`Method`]s and [`HeaderName`]s
    /// allowed, with a cache max-age of [`Duration`].
    /// Returns [`None`] if `origin` isn't allowed.
    ///
    /// Every origin is allowed if [`Self::allow_all_origins`] has been called.
    pub fn check(&self, origin: &Uri) -> Option<(MethodAllowList, &[HeaderName], Duration)> {
        let permitted = self.allow_all_origins
            || self
                .allowed
                .iter()
                .any(|allowed| Self::origin_matches(allowed, origin));

        if permitted {
            Some((self.get_methods(), self.headers.as_slice(), self.cache_for))
        } else {
            None
        }
    }
}
/// The default `cache_for` is 1 hour.
impl Default for AllowList {
    /// Same as [`AllowList::new`] with a client cache duration of one hour.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        Self::new(one_hour)
    }
}

/// The allowed methods.
///
/// The lifetime `'a` is that of the borrowed method list in [`Self::Selected`].
#[derive(Debug)]
#[must_use]
pub enum MethodAllowList<'a> {
    /// All methods are allowed.
    All,
    /// Only the methods in the slice are allowed.
    Selected(&'a [Method]),
}
impl<'a> MethodAllowList<'a> {
    /// Returns `true` if `method` may be used according to this list.
    #[must_use]
    fn allowed(&self, method: &Method) -> bool {
        if let Self::Selected(methods) = self {
            methods.contains(method)
        } else {
            true
        }
    }
    /// Serializes the list to the form used as a header value:
    /// `*` if all methods are allowed, else the method names separated by `, `.
    fn to_bytes(&self) -> Bytes {
        let methods = match self {
            Self::Selected(methods) => methods,
            Self::All => return Bytes::from_static(b"*"),
        };

        let mut joined = BytesMut::with_capacity(24);
        for (index, method) in methods.iter().enumerate() {
            // Separate every method from the one before it.
            if index > 0 {
                joined.extend_from_slice(b", ");
            }
            joined.extend_from_slice(method.as_str().as_bytes());
        }
        joined.freeze()
    }
}

/// Creates the [`Prepare`] extension which answers CORS preflight requests
/// according to `options_cors_settings`.
///
/// A request rejected by [`Cors::check_cors_request`] gets a `403 Forbidden` response.
/// An accepted request gets a `204 No Content` response with the headers
/// `access-control-allow-methods`, `access-control-allow-headers`,
/// and `access-control-max-age` set from the matching rule.
fn options_prepare(options_cors_settings: Arc<Cors>) -> Prepare {
    prepare!(request, _, _, _, move |options_cors_settings: Arc<Cors>| {
        let (methods, headers, cache_for) =
            match options_cors_settings.check_cors_request(request) {
                Some(allowed) => allowed,
                None => {
                    let denied = Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(Bytes::from_static(b"CORS request denied"))
                        .expect("we know this is a good request.");
                    return FatResponse::new(denied, comprash::ServerCachePreference::Full);
                }
            };

        let allowed_methods = methods.to_bytes();

        // The allowed header names, separated by `, `.
        let mut allowed_headers = BytesMut::with_capacity(24);
        for (index, header) in headers.iter().enumerate() {
            if index > 0 {
                allowed_headers.extend_from_slice(b", ");
            }
            allowed_headers.extend_from_slice(header.as_str().as_bytes());
        }
        let allowed_headers = allowed_headers.freeze();

        // Round up to whole seconds: if there is a sub-second part, add 1 second.
        // `u64::from(bool)` returns 1 if true.
        let max_age = cache_for.as_secs() + u64::from(cache_for.subsec_nanos() > 0);

        let builder = Response::builder()
            .status(StatusCode::NO_CONTENT)
            .header(
                "access-control-allow-methods",
                // We know all the characters from [`Method::as_str`] are valid.
                HeaderValue::from_maybe_shared(allowed_methods).unwrap(),
            )
            .header(
                "access-control-allow-headers",
                // We know all the characters from [`HeaderName::as_str()`] are valid.
                // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#impl-From%3CHeaderName%3E
                HeaderValue::from_maybe_shared(allowed_headers).unwrap(),
            )
            .header(
                "access-control-max-age",
                // We know a number is valid
                HeaderValue::try_from(max_age.to_string()).unwrap(),
            );

        // If the response can't be built, answer with an internal server error.
        let response = builder.body(Bytes::new()).unwrap_or_else(|_| {
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(utils::hardcoded_error_body(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    None,
                ))
                .expect("this is a good response.")
        });
        FatResponse::new(response, comprash::ServerCachePreference::None)
    })
}
impl Extensions {
    /// Adds extensions to disallow all CORS requests.
    /// This is added when calling [`Extensions::new`].
    ///
    /// The following is added:
    /// - a prime extension which reroutes every request with an `origin` header which
    ///   isn't the same origin as the request URI to `/./cors_fail`,
    /// - a prepare extension at `/./cors_fail` which responds with `403 Forbidden`,
    /// - a prepare extension at `/./cors_options` which answers preflight requests
    ///   using an empty [`Cors`] ruleset,
    /// - a prime extension which reroutes preflight requests (`OPTIONS` requests with the
    ///   `origin` and `access-control-request-method` headers) to `/./cors_options`.
    pub fn with_disallow_cors(&mut self) -> &mut Self {
        self.add_prime(
            prime!(request, _, _, {
                // Requests without a (readable) `origin` header are left untouched.
                let mismatch = request
                    .headers()
                    .get("origin")
                    .and_then(|origin| origin.to_str().ok())
                    .map_or(false, |origin| {
                        !Cors::is_part_of_origin(origin, request.uri())
                    });
                if mismatch {
                    Some(Uri::from_static("/./cors_fail"))
                } else {
                    None
                }
            }),
            Id::new(16_777_216, "Reroute all CORS requests to /./cors_fail"),
        );

        self.add_prepare_single(
            "/./cors_fail",
            prepare!(_, _, _, _, {
                let response = Response::builder()
                    .status(StatusCode::FORBIDDEN)
                    .body(Bytes::from_static(b"CORS request denied"))
                    .expect("we know this is a good request.");
                FatResponse::new(response, comprash::ServerCachePreference::Full)
            }),
        );
        self.add_prepare_single("/./cors_options", options_prepare(Cors::empty().arc()));
        self.add_prime(
            prime!(request, _, _, {
                let headers = request.headers();
                let is_preflight = request.method() == Method::OPTIONS
                    && headers.get("origin").is_some()
                    && headers.get("access-control-request-method").is_some();
                if is_preflight {
                    Some(Uri::from_static("/./cors_options"))
                } else {
                    None
                }
            }),
            Id::new(16_777_215, "Provides CORS preflight request support"),
        );
        self
    }
    /// Overrides the default handling (deny all) of CORS requests to be `cors_settings`.
    ///
    /// This first calls [`Self::with_disallow_cors`], then
    /// - replaces the preflight handler at `/./cors_options` with one using `cors_settings`,
    /// - adds a prime extension which reroutes requests rejected by
    ///   [`Cors::check_cors_request`] to `/./cors_fail`,
    /// - adds a package extension which sets `access-control-allow-origin` to the request's
    ///   `origin` (and adds `vary: origin`) on responses to allowed requests.
    ///
    /// See [`Cors`] for an example and more info.
    pub fn with_cors(&mut self, cors_settings: Arc<Cors>) -> &mut Self {
        self.with_disallow_cors();

        let options_cors_settings = Arc::clone(&cors_settings);
        self.add_prepare_single("/./cors_options", options_prepare(options_cors_settings));
        let package_cors_settings = Arc::clone(&cors_settings);

        // This priority has to be higher than the one in the [`Self::with_disallow_cors`]'s
        // prime extension.
        // NOTE(review): both extensions currently use the priority 16_777_216;
        // confirm how `add_prime` treats extensions with equal priority.
        self.add_prime(
            prime!(request, _, _, move |cors_settings: Arc<Cors>| {
                let allow = cors_settings.check_cors_request(request);
                if allow.is_some() {
                    None
                } else {
                    Some(Uri::from_static("/./cors_fail"))
                }
            }),
            Id::new(
                16_777_216,
                "Reroute not allowed CORS request to /./cors_fail",
            ),
        );

        // Low priority so it runs last.
        self.add_package(
            package!(response, request, _, move |package_cors_settings: Arc<
                Cors,
            >| {
                if let Some(origin) = request.headers().get("origin") {
                    let allowed = package_cors_settings.check_cors_request(request).is_some();
                    if allowed {
                        let headers = response.headers_mut();
                        headers.insert("access-control-allow-origin", origin.clone());
                        // The header above depends on the request's `origin`; tell caches
                        // so a response isn't reused for a request from another origin.
                        // `append` keeps any `vary` values already present.
                        headers.append("vary", HeaderValue::from_static("origin"));
                    }
                }
            }),
            Id::new(
                -1024,
                "Adds access-control-allow-origin depending on if CORS request is allowed",
            ),
        );

        self
    }
}