1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
//! The compress module of roa.
//! This module provides a middleware `Compress`.
//!
//! ### Example
//!
//! ```rust
//! use roa::compress::{Compress, Level};
//! use roa::body::PowerBody;
//! use roa::core::{App, StatusCode, header::ACCEPT_ENCODING};
//! use async_std::task::spawn;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     pretty_env_logger::init();
//!     let (addr, server) = App::new(())
//!         .gate_fn(|ctx, next| async move {
//!             next().await?;
//!             // the body is gzip-compressed to 202 bytes with quality Level::Fastest.
//!             ctx.resp_mut().await.on_finish(|body| assert_eq!(202, body.consumed()));
//!             Ok(())
//!         })
//!         .gate(Compress(Level::Fastest))
//!         .end(|ctx| async move {
//!             // the size of assets/welcome.html is 236 bytes.
//!             ctx.resp_mut().await.on_finish(|body| assert_eq!(236, body.consumed()));
//!             ctx.write_file("assets/welcome.html").await
//!         })
//!         .run_local()?;
//!     spawn(server);
//!     let client = reqwest::Client::builder().gzip(true).build()?;
//!     let resp = client
//!         .get(&format!("http://{}", addr))
//!         .header(ACCEPT_ENCODING, "gzip")
//!         .send()
//!         .await?;
//!     assert_eq!(StatusCode::OK, resp.status());
//!     Ok(())
//! }
//! ```
pub use async_compression::Level;

use crate::core::header::CONTENT_ENCODING;
use crate::core::{
    async_trait, Body, Context, Error, Middleware, Next, Result, State, StatusCode,
};
use accept_encoding::{parse, Encoding};
use async_compression::futures::bufread::{
    BrotliEncoder, GzipEncoder, ZlibEncoder, ZstdEncoder,
};
use std::sync::Arc;

/// Middleware that picks a content coding from the client's `Accept-Encoding`
/// header and compresses the response body with it.
///
/// The wrapped [`Level`] is the compression quality handed to the encoder.
/// Supported codings: gzip, deflate, brotli, zstd and identity.
#[derive(Clone, Copy, Debug)]
pub struct Compress(pub Level);

impl Default for Compress {
    /// Builds a `Compress` that uses [`Level::Default`].
    fn default() -> Self {
        Compress(Level::Default)
    }
}

#[async_trait]
impl<S: State> Middleware<S> for Compress {
    async fn handle(self: Arc<Self>, ctx: Context<S>, next: Next) -> Result {
        next().await?;
        let body: Body = std::mem::take(&mut *ctx.resp_mut().await);
        let content_encoding = match parse(&ctx.req().await.headers)
            .map_err(|err| Error::new(StatusCode::BAD_REQUEST, err, true))?
        {
            None | Some(Encoding::Gzip) => {
                ctx.resp_mut()
                    .await
                    .write(GzipEncoder::with_quality(body, self.0));
                Encoding::Gzip.to_header_value()
            }
            Some(Encoding::Deflate) => {
                ctx.resp_mut()
                    .await
                    .write(ZlibEncoder::with_quality(body, self.0));
                Encoding::Deflate.to_header_value()
            }
            Some(Encoding::Brotli) => {
                ctx.resp_mut()
                    .await
                    .write(BrotliEncoder::with_quality(body, self.0));
                Encoding::Brotli.to_header_value()
            }
            Some(Encoding::Zstd) => {
                ctx.resp_mut()
                    .await
                    .write(ZstdEncoder::with_quality(body, self.0));
                Encoding::Zstd.to_header_value()
            }
            Some(Encoding::Identity) => {
                ctx.resp_mut().await.write_buf(body);
                Encoding::Identity.to_header_value()
            }
        };
        ctx.resp_mut()
            .await
            .headers
            .append(CONTENT_ENCODING, content_encoding);
        Ok(())
    }
}