#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use std::io::Read;
use std::io::Write;
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use flate2::Compression as GzLevel;
use flate2::write::DeflateEncoder;
use flate2::write::GzEncoder;
use http::HeaderValue;
use http::StatusCode;
use http::header::ACCEPT_ENCODING;
use http::header::CONTENT_ENCODING;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_TYPE;
use http::header::VARY;
use http_body_util::BodyExt;
pub mod brotli_stream;
pub mod deflate_stream;
pub mod gzip_stream;
pub mod zstd_stream;
#[cfg(feature = "zstd")]
use zstd::stream::encode_all as zstd_encode;
use crate::body::TakoBody;
use crate::middleware::Next;
use crate::plugins::TakoPlugin;
use crate::plugins::compression::brotli_stream::stream_brotli;
use crate::plugins::compression::deflate_stream::stream_deflate;
use crate::plugins::compression::gzip_stream::stream_gzip;
#[cfg(feature = "zstd")]
use crate::plugins::compression::zstd_stream::stream_zstd;
use crate::responder::Responder;
use crate::router::Router;
use crate::types::Request;
use crate::types::Response;
/// Content codings this plugin can negotiate and apply to response bodies.
///
/// `as_str` maps each variant to the token used in the `Accept-Encoding`
/// and `Content-Encoding` headers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Encoding {
    /// gzip — header token `"gzip"`.
    Gzip,
    /// Brotli — header token `"br"`.
    Brotli,
    /// raw DEFLATE — header token `"deflate"`.
    Deflate,
    /// Zstandard — header token `"zstd"`; only compiled with the `zstd` feature.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
impl Encoding {
fn as_str(&self) -> &'static str {
match self {
Encoding::Gzip => "gzip",
Encoding::Brotli => "br",
Encoding::Deflate => "deflate",
#[cfg(feature = "zstd")]
Encoding::Zstd => "zstd",
}
}
}
/// Tuning knobs for the compression middleware.
#[derive(Clone)]
pub struct Config {
    /// Codings the server is willing to apply; negotiation picks from these.
    pub enabled: Vec<Encoding>,
    /// Bodies smaller than this many bytes are sent uncompressed.
    pub min_size: usize,
    /// gzip level; the builder clamps it to 0..=9.
    pub gzip_level: u32,
    /// Brotli quality; the builder clamps it to 0..=11.
    pub brotli_level: u32,
    /// DEFLATE level; the builder clamps it to 0..=9.
    pub deflate_level: u32,
    /// zstd level; the builder clamps it to 1..=22.
    #[cfg(feature = "zstd")]
    pub zstd_level: i32,
    /// When true the streaming middleware is installed instead of the
    /// buffer-then-compress one.
    pub stream: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
enabled: vec![Encoding::Gzip, Encoding::Brotli, Encoding::Deflate],
min_size: 1024,
gzip_level: 5,
brotli_level: 5,
deflate_level: 5,
#[cfg(feature = "zstd")]
zstd_level: 3,
stream: false,
}
}
}
/// Fluent builder that accumulates a [`Config`] and produces a [`CompressionPlugin`].
pub struct CompressionBuilder(Config);
impl CompressionBuilder {
pub fn new() -> Self {
Self(Config::default())
}
pub fn enable_gzip(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Gzip) {
self.0.enabled.push(Encoding::Gzip)
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Gzip)
}
self
}
pub fn enable_brotli(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Brotli) {
self.0.enabled.push(Encoding::Brotli)
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Brotli)
}
self
}
pub fn enable_deflate(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Deflate) {
self.0.enabled.push(Encoding::Deflate)
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Deflate)
}
self
}
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
pub fn enable_zstd(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Zstd) {
self.0.enabled.push(Encoding::Zstd)
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Zstd)
}
self
}
pub fn enable_stream(mut self, stream: bool) -> Self {
self.0.stream = stream;
self
}
pub fn min_size(mut self, bytes: usize) -> Self {
self.0.min_size = bytes;
self
}
pub fn gzip_level(mut self, lvl: u32) -> Self {
self.0.gzip_level = lvl.min(9);
self
}
pub fn brotli_level(mut self, lvl: u32) -> Self {
self.0.brotli_level = lvl.min(11);
self
}
pub fn deflate_level(mut self, lvl: u32) -> Self {
self.0.deflate_level = lvl.min(9);
self
}
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
pub fn zstd_level(mut self, lvl: i32) -> Self {
self.0.zstd_level = lvl.clamp(1, 22);
self
}
pub fn build(self) -> CompressionPlugin {
CompressionPlugin { cfg: self.0 }
}
}
/// Wrapper distinguishing which middleware path produced the response:
/// `Plain` from the buffered path, `Stream` from the streaming path.
/// Both delegate straight to the inner responder.
pub enum CompressionResponse<R>
where
    R: Responder,
{
    /// Produced by the buffered (`compress_middleware`) path.
    Plain(R),
    /// Produced by the streaming (`compress_stream_middleware`) path.
    Stream(R),
}
impl<R> Responder for CompressionResponse<R>
where
    R: Responder,
{
    /// Unwraps whichever variant is present and forwards to the inner
    /// responder — both variants behave identically here.
    fn into_response(self) -> Response {
        match self {
            Self::Plain(inner) | Self::Stream(inner) => inner.into_response(),
        }
    }
}
/// Router plugin that installs response-compression middleware when set up.
///
/// Configure it via [`CompressionBuilder`]; the stored [`Config`] decides
/// between buffered and streaming compression.
#[derive(Clone)]
#[doc(alias = "compression")]
#[doc(alias = "gzip")]
#[doc(alias = "brotli")]
#[doc(alias = "deflate")]
pub struct CompressionPlugin {
    // Snapshot of the builder's settings, cloned into the middleware closure.
    cfg: Config,
}
impl Default for CompressionPlugin {
fn default() -> Self {
Self {
cfg: Config::default(),
}
}
}
#[async_trait]
impl TakoPlugin for CompressionPlugin {
    fn name(&self) -> &'static str {
        "CompressionPlugin"
    }

    /// Registers one middleware on the router. Each request dispatches to the
    /// streaming or buffered compressor according to `cfg.stream`.
    fn setup(&self, router: &Router) -> Result<()> {
        let cfg = self.cfg.clone();
        router.middleware(move |req, next| {
            // Clone per request: the config moves into the async block below.
            let cfg = cfg.clone();
            async move {
                // `stream` is a plain `Copy` bool — read it directly instead
                // of cloning, and branch without comparing against `false`.
                if cfg.stream {
                    CompressionResponse::Stream(
                        compress_stream_middleware(req, next, cfg)
                            .await
                            .into_response(),
                    )
                } else {
                    CompressionResponse::Plain(
                        compress_middleware(req, next, cfg).await.into_response(),
                    )
                }
            }
        });
        Ok(())
    }
}
/// Buffered compression middleware: runs the inner handler, collects the full
/// body, and re-emits it compressed when negotiation and the guards allow.
///
/// Guards (any failure returns the response untouched):
/// - status must be 2xx or 304,
/// - no pre-existing `Content-Encoding`,
/// - `Content-Type` must look text-like (text/*, json, javascript, xml),
/// - body must be at least `cfg.min_size` bytes.
async fn compress_middleware(req: Request, next: Next, cfg: Config) -> impl Responder {
    // Read Accept-Encoding before the request is consumed by the handler.
    let accepted = req
        .headers()
        .get(ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_ascii_lowercase();
    let mut resp = next.run(req).await;
    let chosen = choose_encoding(&accepted, &cfg.enabled);
    let status = resp.status();
    if !(status.is_success() || status == StatusCode::NOT_MODIFIED) {
        return resp.into_response();
    }
    if resp.headers().contains_key(CONTENT_ENCODING) {
        return resp.into_response();
    }
    if let Some(ct) = resp.headers().get(CONTENT_TYPE) {
        let ct = ct.to_str().unwrap_or("");
        if !(ct.starts_with("text/")
            || ct.contains("json")
            || ct.contains("javascript")
            || ct.contains("xml"))
        {
            return resp.into_response();
        }
    }
    // Buffer the whole body; a broken stream becomes a 502.
    let body_bytes = match resp.body_mut().collect().await {
        Ok(c) => c.to_bytes(),
        Err(_) => {
            *resp.status_mut() = StatusCode::BAD_GATEWAY;
            *resp.body_mut() = TakoBody::empty();
            return resp.into_response();
        }
    };
    if body_bytes.len() < cfg.min_size {
        *resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
        return resp.into_response();
    }
    if let Some(enc) = chosen {
        let result = match enc {
            Encoding::Gzip => compress_gzip(&body_bytes, cfg.gzip_level),
            Encoding::Brotli => compress_brotli(&body_bytes, cfg.brotli_level),
            Encoding::Deflate => compress_deflate(&body_bytes, cfg.deflate_level),
            #[cfg(feature = "zstd")]
            Encoding::Zstd => compress_zstd(&body_bytes, cfg.zstd_level),
        };
        match result {
            Ok(compressed) => {
                *resp.body_mut() = TakoBody::from(Bytes::from(compressed));
                resp.headers_mut()
                    .insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
                // The stored length refers to the identity body — drop it.
                resp.headers_mut().remove(CONTENT_LENGTH);
                resp.headers_mut()
                    .insert(VARY, HeaderValue::from_static("Accept-Encoding"));
            }
            // BUG FIX: the previous code fell back to the *uncompressed*
            // bytes on encoder failure but still set `Content-Encoding`,
            // yielding an undecodable response. On failure, serve the
            // identity body with no encoding header at all.
            Err(_) => {
                *resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
            }
        }
    } else {
        *resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
    }
    resp.into_response()
}
/// Streaming compression middleware: same guards as the buffered variant, but
/// the body is wrapped in an incremental encoder instead of being collected.
pub async fn compress_stream_middleware(req: Request, next: Next, cfg: Config) -> impl Responder {
    // Capture the client's Accept-Encoding before the request moves on.
    let accept_header = req
        .headers()
        .get(ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_ascii_lowercase())
        .unwrap_or_default();
    let mut response = next.run(req).await;
    let negotiated = choose_encoding(&accept_header, &cfg.enabled);

    // Only 2xx and 304 responses are candidates.
    let status = response.status();
    if !status.is_success() && status != StatusCode::NOT_MODIFIED {
        return response.into_response();
    }
    // Never double-encode an already-encoded body.
    if response.headers().get(CONTENT_ENCODING).is_some() {
        return response.into_response();
    }
    // Restrict to text-like content types.
    if let Some(content_type) = response.headers().get(CONTENT_TYPE) {
        let ct = content_type.to_str().unwrap_or("");
        let compressible = ct.starts_with("text/")
            || ct.contains("json")
            || ct.contains("javascript")
            || ct.contains("xml");
        if !compressible {
            return response.into_response();
        }
    }
    // Honor the size threshold when a Content-Length is declared up front.
    let declared_len = response
        .headers()
        .get(CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse::<usize>().ok());
    if matches!(declared_len, Some(len) if len < cfg.min_size) {
        return response.into_response();
    }

    if let Some(enc) = negotiated {
        // Swap the original body out so it can be fed to the encoder.
        let original = std::mem::replace(response.body_mut(), TakoBody::empty());
        let encoded = match enc {
            Encoding::Gzip => stream_gzip(original, cfg.gzip_level),
            Encoding::Brotli => stream_brotli(original, cfg.brotli_level),
            Encoding::Deflate => stream_deflate(original, cfg.deflate_level),
            #[cfg(feature = "zstd")]
            Encoding::Zstd => stream_zstd(original, cfg.zstd_level),
        };
        *response.body_mut() = encoded;
        let headers = response.headers_mut();
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
        // The declared length no longer matches the encoded stream.
        headers.remove(CONTENT_LENGTH);
        headers.insert(VARY, HeaderValue::from_static("Accept-Encoding"));
    }
    response.into_response()
}
/// Negotiates a coding from an `Accept-Encoding` header value, preferring
/// Brotli, then gzip, then DEFLATE, then (feature-gated) zstd.
///
/// Parses the header as comma-separated entries of `token[;q=value]` rather
/// than raw substring search. This fixes two defects of the previous
/// implementation: `gzip;q=0` (an explicit client refusal, RFC 9110 §12.5.3)
/// no longer selects gzip, and unrelated substrings can no longer
/// false-positive a token match.
fn choose_encoding(header: &str, enabled: &[Encoding]) -> Option<Encoding> {
    // Defensive lowercase: callers may pass raw header text.
    let header = header.to_ascii_lowercase();
    // True when `enc` is enabled and the header lists its token with a
    // non-zero quality value.
    let acceptable = |enc: Encoding| {
        enabled.contains(&enc)
            && header.split(',').any(|entry| {
                let mut parts = entry.split(';');
                let token = parts.next().unwrap_or("").trim();
                if token != enc.as_str() {
                    return false;
                }
                // `token;q=0` explicitly forbids this coding.
                !parts.any(|param| {
                    param
                        .trim()
                        .strip_prefix("q=")
                        .and_then(|q| q.trim().parse::<f32>().ok())
                        .map_or(false, |q| q == 0.0)
                })
            })
    };
    if acceptable(Encoding::Brotli) {
        return Some(Encoding::Brotli);
    }
    if acceptable(Encoding::Gzip) {
        return Some(Encoding::Gzip);
    }
    if acceptable(Encoding::Deflate) {
        return Some(Encoding::Deflate);
    }
    #[cfg(feature = "zstd")]
    if acceptable(Encoding::Zstd) {
        return Some(Encoding::Zstd);
    }
    None
}
/// One-shot gzip compression of `data` at level `lvl`.
///
/// # Errors
/// Returns any I/O error produced while encoding or finalizing the stream.
fn compress_gzip(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let mut source = data;
    let mut encoder = GzEncoder::new(Vec::new(), GzLevel::new(lvl));
    // `&[u8]` implements `Read`, so copy the slice through the encoder.
    std::io::copy(&mut source, &mut encoder)?;
    encoder.finish()
}
/// One-shot Brotli compression of `data` at quality `lvl` with a 4 KiB
/// internal buffer and window size 22 (matching the original call).
///
/// # Errors
/// Propagates the I/O error reported by the Brotli reader. The previous
/// version replaced it with a generic `ErrorKind::Other`, discarding the
/// actual failure cause; `read_to_end` already returns `io::Result`, so `?`
/// suffices and keeps the original error.
fn compress_brotli(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let mut out = Vec::new();
    brotli::CompressorReader::new(data, 4096, lvl, 22).read_to_end(&mut out)?;
    Ok(out)
}
/// One-shot raw-DEFLATE compression of `data` at level `lvl`.
///
/// # Errors
/// Returns any I/O error produced while encoding or finalizing the stream.
fn compress_deflate(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let mut source = data;
    let mut encoder = DeflateEncoder::new(Vec::new(), flate2::Compression::new(lvl));
    // `&[u8]` implements `Read`, so copy the slice through the encoder.
    std::io::copy(&mut source, &mut encoder)?;
    encoder.finish()
}
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
fn compress_zstd(data: &[u8], lvl: i32) -> std::io::Result<Vec<u8>> {
zstd_encode(data, lvl)
}