rama-http 0.3.0-rc1

rama http layers, services and other utilities
use super::predicate::DefaultPredicate;
use super::{Compression, Predicate};
use crate::headers::encoding::AcceptEncoding;
use crate::layer::util::compression::CompressionLevel;
use rama_core::Layer;

/// Compress response bodies of the underlying service.
///
/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
/// `Content-Encoding` header to responses.
///
/// See the [module docs](crate::layer::compression) for more details.
#[derive(Clone, Debug)]
pub struct CompressionLayer<P = DefaultPredicate> {
    /// Encodings this layer is allowed to select; toggled through the
    /// gzip / deflate / br / zstd setters and copied into every [`Compression`] service.
    accept: AcceptEncoding,
    /// Compression predicate (see [`Predicate`]) handed to the [`Compression`] service.
    /// `P` defaults to [`DefaultPredicate`].
    predicate: P,
    /// Whether to "respect" an existing `Content-Encoding` response header where possible;
    /// `false` in the [`Default`] implementation.
    respect_content_encoding_if_possible: bool,
    /// Compression level forwarded to the [`Compression`] service.
    quality: CompressionLevel,
    /// Whether to answer `406 Not Acceptable` when `Accept-Encoding` rejects every
    /// representation; `true` in the [`Default`] implementation.
    enforce_not_acceptable: bool,
}

impl<P> Default for CompressionLayer<P>
where
    P: Default,
{
    /// Builds a layer with the default [`AcceptEncoding`] set, a default-constructed
    /// predicate and the default [`CompressionLevel`]. `respect_content_encoding_if_possible`
    /// starts disabled and `406 Not Acceptable` enforcement starts enabled.
    fn default() -> Self {
        let accept = AcceptEncoding::default();
        let predicate = P::default();
        let quality = CompressionLevel::default();

        Self {
            accept,
            predicate,
            quality,
            respect_content_encoding_if_possible: false,
            enforce_not_acceptable: true,
        }
    }
}

impl<S, P> Layer<S> for CompressionLayer<P>
where
    P: Predicate,
{
    type Service = Compression<S, P>;

    /// Wraps `inner` without consuming the layer.
    ///
    /// The configuration is cloned (the accepted encodings, quality and flags are plain
    /// values; the predicate is cloned) and the copy is turned into the service.
    fn layer(&self, inner: S) -> Self::Service {
        Self::clone(self).into_layer(inner)
    }

    /// Wraps `inner`, moving this layer's configuration into the resulting
    /// [`Compression`] service.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            accept,
            predicate,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        } = self;

        Compression {
            inner,
            accept,
            predicate,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        }
    }
}

impl CompressionLayer {
    /// Creates a new [`CompressionLayer`] that uses the [`DefaultPredicate`], with every
    /// other setting taken from its [`Default`] implementation.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replace the current compression predicate.
    ///
    /// Every other setting (enabled encodings, quality, flags) is carried over unchanged.
    /// This method is defined on `CompressionLayer` with its default type parameter, so it
    /// can only be called while the layer still uses [`DefaultPredicate`].
    pub fn with_compress_predicate<C>(self, predicate: C) -> CompressionLayer<C>
    where
        C: Predicate,
    {
        let Self {
            accept,
            predicate: _,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        } = self;

        CompressionLayer {
            accept,
            predicate,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        }
    }
}

// Builder-style configuration, available for every predicate type `P`.
//
// Each `generate_set_and_with!` invocation below takes a `fn name(mut self, ..) -> Self`
// definition and expands it into configuration methods. The tests in this file use the
// generated `with_*` form (e.g. `with_gzip`, `with_quality`).
impl<P> CompressionLayer<P> {
    rama_utils::macros::generate_set_and_with! {
        /// Sets whether to enable the gzip encoding.
        ///
        /// A disabled encoding is not selected, even when the client's `Accept-Encoding`
        /// lists it.
        pub fn gzip(mut self, enable: bool) -> Self {
            self.accept.set_gzip(enable);
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Sets whether to enable the Deflate encoding.
        ///
        /// A disabled encoding is not selected, even when the client's `Accept-Encoding`
        /// lists it.
        pub fn deflate(mut self, enable: bool) -> Self {
            self.accept.set_deflate(enable);
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Sets whether to enable the Brotli encoding.
        ///
        /// A disabled encoding is not selected, even when the client's `Accept-Encoding`
        /// lists it.
        pub fn br(mut self, enable: bool) -> Self {
            self.accept.set_br(enable);
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Sets whether to enable the Zstd encoding.
        ///
        /// A disabled encoding is not selected, even when the client's `Accept-Encoding`
        /// lists it.
        pub fn zstd(mut self, enable: bool) -> Self {
            self.accept.set_zstd(enable);
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Sets the compression quality.
        ///
        /// The level is forwarded to the [`Compression`] service created by this layer.
        pub fn quality(mut self, quality: CompressionLevel) -> Self {
            self.quality = quality;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Allow responses with content-encoding.
        ///
        /// Useful in case your stack uses that response header as preference.
        /// Not something you want for regular servers or proxies however,
        /// or most use cases for that matter.
        ///
        /// This takes no argument: it only ever switches the option on
        /// (it is off by default).
        pub fn respect_content_encoding_if_possible(mut self) -> Self {
            self.respect_content_encoding_if_possible = true;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Sets whether to respond with `406 Not Acceptable` when the client's
        /// `Accept-Encoding` header rejects every available representation
        /// (e.g. `*;q=0` or a lone `identity;q=0`), as recommended by RFC 9110 §12.5.3.
        ///
        /// Enabled by default. Disable to opt out and instead fall back to sending an
        /// uncompressed (identity) response regardless of the client's stated preference.
        pub fn enforce_not_acceptable(mut self, enable: bool) -> Self {
            self.enforce_not_acceptable = enable;
            self
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Request, Response, body::util::BodyExt, header::ACCEPT_ENCODING};
    use rama_core::Service;
    use rama_core::service::service_fn;
    use rama_core::stream::io::ReaderStream;
    use rama_http_types::Body;
    use std::convert::Infallible;
    use tokio::fs::File;

    /// File served by [`handle`]. It is anchored at the crate root so the tests do not
    /// depend on the working directory the test binary is started from.
    const FIXTURE_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");

    /// Test service: streams the contents of [`FIXTURE_PATH`] back as the response body.
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        // Open the file.
        let file = File::open(FIXTURE_PATH).await.expect("file missing");
        // Convert the file into a `Stream`.
        let stream = ReaderStream::new(file);
        // Convert the `Stream` into a `Body`.
        let body = Body::from_stream(stream);
        // Create response.
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), rama_core::error::BoxError> {
        use std::io::Read;

        fn decode<R: Read>(mut r: R) -> std::io::Result<Vec<u8>> {
            let mut buf = Vec::new();
            r.read_to_end(&mut buf)?;
            Ok(buf)
        }

        // Read the source file once so we can verify each response round-trips to the same bytes.
        let expected = tokio::fs::read(FIXTURE_PATH).await?;

        // Configure a layer that only offers deflate, then confirm the response is actually
        // deflate-encoded by decoding it and comparing to the original content. Every other
        // encoding is switched off, zstd included (these builders would otherwise leave it
        // untouched). That way the outcome does not depend on which encodings the request
        // happens to list.
        let deflate_only_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_gzip(false)
            .with_zstd(false);

        let service = deflate_only_layer.into_layer(service_fn(handle));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.serve(request).await?;

        assert_eq!(response.headers()["content-encoding"], "deflate");

        let deflate_body = response.into_body().collect().await?.to_bytes();

        // The "deflate" Content-Encoding is RFC 1950 zlib framing (2-byte header + Adler-32),
        // not raw RFC 1951 deflate, so use ZlibDecoder rather than DeflateDecoder.
        let decoded = decode(flate2::bufread::ZlibDecoder::new(&deflate_body[..]))?;
        assert_eq!(decoded, expected);

        // Same check for brotli: only br stays enabled.
        let br_only_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_gzip(false)
            .with_deflate(false)
            .with_zstd(false);

        let service = br_only_layer.into_layer(service_fn(handle));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.serve(request).await?;

        assert_eq!(response.headers()["content-encoding"], "br");

        let br_body = response.into_body().collect().await?.to_bytes();

        // 4096 is the decoder's internal read-buffer size, not a content-length bound.
        let decoded = decode(brotli::Decompressor::new(&br_body[..], 4096))?;
        assert_eq!(decoded, expected);

        Ok(())
    }

    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), rama_core::error::BoxError> {
        // Ensure zstd compression never uses a window larger than 8 MiB: browsers do not
        // accept responses that require a 16 MiB+ window.

        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18 * 1024 * 1024])))
        }
        // zstd may pick a smaller window when it knows the input is small. An 18 MiB body
        // makes sure it would otherwise want a >= 16 MiB window (the encoder may not be
        // able to see the total input size through a streamed body).

        let zstd_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_deflate(false)
            .with_gzip(false)
            .with_zstd(true);

        let service = zstd_layer.into_layer(service_fn(zeroes));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.serve(request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();
        let mut dec = zstd::Decoder::new(&*bytes)?;
        dec.window_log_max(23)?; // Limit window size accepted by decoder to 2 ^ 23 bytes (8MiB)

        std::io::copy(&mut dec, &mut std::io::sink())?;

        Ok(())
    }
}