rkt 0.6.0

Web framework with a focus on usability, security, extensibility, and speed. (Community Fork)
Documentation
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use crate::fairing::{Fairing, Info, Kind};
use crate::http::{uncased::UncasedStr, Header};
use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
use crate::trace::{Trace, TraceAll};
use crate::{Config, Orbit, Request, Response, Rocket};

/// A [`Fairing`] that injects browser security and privacy headers into all
/// outgoing responses.
///
/// # Usage
///
/// To use `Shield`, first construct an instance of it. To use the default
/// set of headers, construct with [`Shield::default()`](#method.default).
/// For an instance with no preset headers, use [`Shield::new()`]. To
/// enable an additional header, use [`enable()`](Shield::enable()), and to
/// disable a header, use [`disable()`](Shield::disable()):
///
/// ```rust
///
/// use rkt::shield::Shield;
/// use rkt::shield::{NoSniff, XssFilter, ExpectCt};
///
/// // A `Shield` with the default headers:
/// let shield = Shield::default();
///
/// // A `Shield` with the default headers minus `NoSniff`:
/// let shield = Shield::default().disable::<NoSniff>();
///
/// // A `Shield` with the default headers plus `ExpectCt`.
/// let shield = Shield::default().enable(ExpectCt::default());
///
/// // A `Shield` with only `XssFilter` and `ExpectCt`. Start from
/// // `Shield::new()` so that none of the default headers are included.
/// let shield = Shield::new()
///     .enable(XssFilter::default())
///     .enable(ExpectCt::default());
/// ```
///
/// Then, attach the instance of `Shield` to your application's instance of
/// `Rocket`:
///
/// ```rust
/// # use rkt::shield::Shield;
/// # let shield = Shield::default();
/// rkt::build()
///     // ...
///     .attach(shield)
/// # ;
/// ```
///
/// The fairing will inject all enabled headers into all outgoing responses
/// _unless_ the response already contains a header with the same name. If it
/// does contain the header, a warning is emitted, and the header is not
/// overwritten.
///
/// # TLS and HSTS
///
/// If every endpoint is TLS-enabled when the application is launched in a
/// non-debug profile and no [`Hsts`] policy has been enabled, HSTS is
/// automatically enabled with its default policy and a warning is logged. To
/// get rid of this warning, explicitly [`Shield::enable()`] an [`Hsts`]
/// policy. This check is skipped entirely for a `Shield` that has no enabled
/// policies.
pub struct Shield {
    /// Enabled policies where the key is the header name.
    policies: HashMap<&'static UncasedStr, Header<'static>>,
    /// Whether to enforce HSTS even though the user didn't enable it.
    ///
    /// Stored as an atomic because the flag is set from the liftoff callback,
    /// which only has shared (`&self`) access to the fairing.
    force_hsts: AtomicBool,
}

impl Clone for Shield {
    fn clone(&self) -> Self {
        Self {
            policies: self.policies.clone(),
            force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
        }
    }
}

impl Default for Shield {
    /// Builds a `Shield` with the default set of policies enabled:
    /// [`NoSniff`], [`Frame`], and [`Permission`], each in its default
    /// configuration. See the [table] for a description of these policies.
    ///
    /// [table]: ./#supported-headers
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use rkt::shield::Shield;
    ///
    /// let shield = Shield::default();
    /// ```
    fn default() -> Self {
        // Start from an empty shield and layer each default policy on top.
        let shield = Self::new();
        let shield = shield.enable(NoSniff::default());
        let shield = shield.enable(Frame::default());
        shield.enable(Permission::default())
    }
}

impl Shield {
    /// Returns an instance of `Shield` with no headers enabled.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use rkt::shield::Shield;
    ///
    /// let shield = Shield::new();
    /// ```
    pub fn new() -> Self {
        Shield {
            policies: HashMap::new(),
            force_hsts: AtomicBool::new(false),
        }
    }

    /// Enables the policy header `policy`.
    ///
    /// If the policy was previously enabled, the configuration is replaced
    /// with that of `policy`.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use rkt::shield::Shield;
    /// use rkt::shield::NoSniff;
    ///
    /// let shield = Shield::new().enable(NoSniff::default());
    /// ```
    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
        self.policies.insert(P::NAME.into(), policy.header());
        self
    }

    /// Disables the policy header `policy`.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use rkt::shield::Shield;
    /// use rkt::shield::NoSniff;
    ///
    /// let shield = Shield::default().disable::<NoSniff>();
    /// ```
    pub fn disable<P: Policy>(mut self) -> Self {
        self.policies.remove(UncasedStr::new(P::NAME));
        self
    }

    /// Returns `true` if the policy `P` is enabled.
    ///
    /// # Example
    ///
    /// ```rust
    ///
    /// use rkt::shield::Shield;
    /// use rkt::shield::{Permission, NoSniff, Frame};
    /// use rkt::shield::{Prefetch, ExpectCt, Referrer};
    ///
    /// let shield = Shield::default();
    ///
    /// assert!(shield.is_enabled::<NoSniff>());
    /// assert!(shield.is_enabled::<Frame>());
    /// assert!(shield.is_enabled::<Permission>());
    ///
    /// assert!(!shield.is_enabled::<Prefetch>());
    /// assert!(!shield.is_enabled::<ExpectCt>());
    /// assert!(!shield.is_enabled::<Referrer>());
    /// ```
    pub fn is_enabled<P: Policy>(&self) -> bool {
        self.policies.contains_key(UncasedStr::new(P::NAME))
    }
}

#[crate::async_trait]
impl Fairing for Shield {
    fn info(&self) -> Info {
        Info {
            name: "Shield",
            kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
        }
    }

    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
        if self.policies.is_empty() {
            return;
        }

        let force_hsts = rocket.endpoints().all(|v| v.is_tls())
            && rocket.figment().profile() != Config::DEBUG_PROFILE
            && !self.is_enabled::<Hsts>();

        if force_hsts {
            self.force_hsts.store(true, Ordering::Release);
        }

        span_info!("shield", policies = self.policies.len() => {
            self.policies.values().trace_all_info();

            if force_hsts {
                warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
                    Shield has enabled a default HSTS policy.\n\
                    To remove this warning, configure an HSTS policy.");
            }
        })
    }

    async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
        // Set all of the headers in `self.policies` in `response` as long as
        // the header is not already in the response.
        for header in self.policies.values() {
            if response.headers().contains(header.name()) {
                span_warn!("shield", "shield refusing to overwrite existing response header" => {
                    header.trace_warn();
                });

                continue;
            }

            response.set_header(header.clone());
        }
    }
}