#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use std::io::Read;
use std::io::Write;
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use flate2::Compression as GzLevel;
use flate2::write::DeflateEncoder;
use flate2::write::GzEncoder;
use http::HeaderValue;
use http::StatusCode;
use http::header::ACCEPT_ENCODING;
use http::header::CONTENT_ENCODING;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_TYPE;
use http::header::VARY;
use http_body_util::BodyExt;
pub mod brotli_stream;
pub mod deflate_stream;
pub mod gzip_stream;
pub mod zstd_stream;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::Next;
use tako_rs_core::plugins::TakoPlugin;
use tako_rs_core::responder::Responder;
use tako_rs_core::router::Router;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
#[cfg(feature = "zstd")]
use zstd::stream::encode_all as zstd_encode;
use crate::plugins::compression::brotli_stream::stream_brotli;
use crate::plugins::compression::deflate_stream::stream_deflate;
use crate::plugins::compression::gzip_stream::stream_gzip;
#[cfg(feature = "zstd")]
use crate::plugins::compression::zstd_stream::stream_zstd;
/// A content-coding this plugin knows how to apply to a response body.
///
/// Each variant maps to the token used in `Accept-Encoding` / `Content-Encoding`
/// headers via [`Encoding::as_str`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Encoding {
    /// The `gzip` coding.
    Gzip,
    /// The Brotli coding (`br`).
    Brotli,
    /// The `deflate` coding.
    Deflate,
    /// The `zstd` coding; only present when the `zstd` feature is enabled.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}

impl Encoding {
    /// Header token for this coding, e.g. `"br"` for [`Encoding::Brotli`].
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Gzip => "gzip",
            Self::Brotli => "br",
            Self::Deflate => "deflate",
            #[cfg(feature = "zstd")]
            Self::Zstd => "zstd",
        }
    }
}
/// Decides which response `Content-Type`s are eligible for compression.
///
/// Before any variant is consulted, media-type parameters (everything after the
/// first `;`) are stripped and surrounding whitespace is trimmed, so
/// `"text/html; charset=utf-8"` is checked as `"text/html"`.
#[derive(Clone, Default)]
pub enum ContentTypePolicy {
    /// Built-in heuristic, compared ASCII case-insensitively: any `text/*` type
    /// except `text/event-stream`, plus any type whose name contains `json`,
    /// `javascript` or `xml`.
    #[default]
    Default,
    /// Matches when the media type equals one of the entries (ASCII case-insensitive).
    Exact(Vec<String>),
    /// Matches when the media type starts with one of the entries (ASCII case-insensitive).
    Prefix(Vec<String>),
    /// Delegates to a user predicate, which receives the parameter-stripped,
    /// trimmed media type with its original casing.
    Custom(std::sync::Arc<dyn Fn(&str) -> bool + Send + Sync + 'static>),
}

impl ContentTypePolicy {
    /// Returns `true` when a response with Content-Type `ct` may be compressed.
    fn matches(&self, ct: &str) -> bool {
        let ct = ct.split(';').next().unwrap_or(ct).trim();
        match self {
            Self::Default => {
                // Media types are case-insensitive (RFC 9110 §8.3.1), consistent
                // with the `Exact` / `Prefix` variants below.
                let ct = ct.to_ascii_lowercase();
                // Server-sent events are an open-ended stream: the buffered path
                // collects the whole body (it would never finish), and a streaming
                // encoder may hold events back. Never compress them by default.
                if ct == "text/event-stream" {
                    return false;
                }
                ct.starts_with("text/")
                    || ct.contains("json")
                    || ct.contains("javascript")
                    || ct.contains("xml")
            }
            Self::Exact(list) => list.iter().any(|m| m.eq_ignore_ascii_case(ct)),
            // Case-insensitive byte-prefix comparison without allocating a
            // lowercased copy per candidate.
            Self::Prefix(list) => list.iter().any(|m| {
                ct.as_bytes()
                    .get(..m.len())
                    .is_some_and(|head| head.eq_ignore_ascii_case(m.as_bytes()))
            }),
            Self::Custom(f) => f(ct),
        }
    }
}
/// Settings shared by the buffered and streaming compression middlewares.
///
/// Usually produced by `CompressionBuilder`, whose level setters clamp their
/// inputs; assigning these public fields directly bypasses that clamping.
#[derive(Clone)]
pub struct Config {
    /// Encodings the server may apply; which one (if any) is used also depends
    /// on the request's `Accept-Encoding` header.
    pub enabled: Vec<Encoding>,
    /// Size threshold in bytes. The buffered path leaves bodies shorter than this
    /// uncompressed; the streaming path applies it only when the response carries
    /// a parseable `Content-Length` header.
    pub min_size: usize,
    /// gzip compression level (the builder clamps it to at most 9).
    pub gzip_level: u32,
    /// Brotli quality (the builder clamps it to at most 11).
    pub brotli_level: u32,
    /// deflate compression level (the builder clamps it to at most 9).
    pub deflate_level: u32,
    /// zstd compression level (the builder clamps it to 1..=22).
    #[cfg(feature = "zstd")]
    pub zstd_level: i32,
    /// `true` selects the streaming middleware; `false` buffers the whole body
    /// before compressing it.
    pub stream: bool,
    /// Policy applied to the response `Content-Type`; responses without a
    /// `Content-Type` header are not filtered by it.
    pub content_types: ContentTypePolicy,
    /// When `true`, compression is skipped if the request carries
    /// `Authorization`, `Proxy-Authorization` or `Cookie`, or the response sets
    /// `Set-Cookie` (presumably to mitigate compression side channels such as
    /// BREACH — intent not stated in this file).
    pub protect_sensitive: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
enabled: vec![Encoding::Gzip, Encoding::Brotli, Encoding::Deflate],
min_size: 1024,
gzip_level: 5,
brotli_level: 5,
deflate_level: 5,
#[cfg(feature = "zstd")]
zstd_level: 3,
stream: false,
content_types: ContentTypePolicy::default(),
protect_sensitive: true,
}
}
}
pub struct CompressionBuilder(Config);
impl Default for CompressionBuilder {
fn default() -> Self {
Self::new()
}
}
impl CompressionBuilder {
pub fn new() -> Self {
Self(Config::default())
}
pub fn enable_gzip(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Gzip) {
self.0.enabled.push(Encoding::Gzip);
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Gzip);
}
self
}
pub fn enable_brotli(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Brotli) {
self.0.enabled.push(Encoding::Brotli);
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Brotli);
}
self
}
pub fn enable_deflate(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Deflate) {
self.0.enabled.push(Encoding::Deflate);
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Deflate);
}
self
}
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
pub fn enable_zstd(mut self, yes: bool) -> Self {
if yes && !self.0.enabled.contains(&Encoding::Zstd) {
self.0.enabled.push(Encoding::Zstd);
}
if !yes {
self.0.enabled.retain(|e| *e != Encoding::Zstd);
}
self
}
pub fn enable_stream(mut self, stream: bool) -> Self {
self.0.stream = stream;
self
}
pub fn min_size(mut self, bytes: usize) -> Self {
self.0.min_size = bytes;
self
}
pub fn content_types(mut self, policy: ContentTypePolicy) -> Self {
self.0.content_types = policy;
self
}
pub fn gzip_level(mut self, lvl: u32) -> Self {
self.0.gzip_level = lvl.min(9);
self
}
pub fn brotli_level(mut self, lvl: u32) -> Self {
self.0.brotli_level = lvl.min(11);
self
}
pub fn deflate_level(mut self, lvl: u32) -> Self {
self.0.deflate_level = lvl.min(9);
self
}
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
pub fn zstd_level(mut self, lvl: i32) -> Self {
self.0.zstd_level = lvl.clamp(1, 22);
self
}
pub fn protect_sensitive(mut self, on: bool) -> Self {
self.0.protect_sensitive = on;
self
}
pub fn build(self) -> CompressionPlugin {
CompressionPlugin { cfg: self.0 }
}
}
/// Response wrapper produced by the compression middleware.
///
/// The variant records which middleware path produced the inner responder;
/// converting into a response behaves the same for both.
pub enum CompressionResponse<R>
where
    R: Responder,
{
    /// Produced by the buffered path.
    Plain(R),
    /// Produced by the streaming path.
    Stream(R),
}

impl<R> Responder for CompressionResponse<R>
where
    R: Responder,
{
    fn into_response(self) -> Response {
        let (Self::Plain(inner) | Self::Stream(inner)) = self;
        inner.into_response()
    }
}
#[derive(Clone)]
#[doc(alias = "compression")]
#[doc(alias = "gzip")]
#[doc(alias = "brotli")]
#[doc(alias = "deflate")]
pub struct CompressionPlugin {
cfg: Config,
}
impl Default for CompressionPlugin {
fn default() -> Self {
Self {
cfg: Config::default(),
}
}
}
#[async_trait]
impl TakoPlugin for CompressionPlugin {
    fn name(&self) -> &'static str {
        "CompressionPlugin"
    }

    /// Registers a router middleware that runs every request through either the
    /// streaming or the buffered compression path, as selected by `cfg.stream`.
    fn setup(&self, router: &Router) -> Result<()> {
        let template = self.cfg.clone();
        // The mode is fixed once the plugin is set up, so read it a single time.
        let streaming = template.stream;
        router.middleware(move |req, next| {
            // NOTE(review): the whole Config (Vec + policy) is cloned per request.
            let cfg = template.clone();
            async move {
                if streaming {
                    let resp = compress_stream_middleware(req, next, cfg).await;
                    CompressionResponse::Stream(resp.into_response())
                } else {
                    let resp = compress_middleware(req, next, cfg).await;
                    CompressionResponse::Plain(resp.into_response())
                }
            }
        });
        Ok(())
    }
}
/// Buffered compression middleware.
///
/// Runs the inner handler, then, if every precondition holds, collects the whole
/// response body into memory and replaces it with a compressed copy.
///
/// The response is returned untouched (no `Vary` added) when:
/// - the status is neither 2xx nor `304`, or is `206 Partial Content` (its
///   `Content-Range` describes offsets into the uncompressed bytes);
/// - the handler already set `Content-Encoding`;
/// - `protect_sensitive` is on and the request carried credentials or the
///   response sets a cookie;
/// - a `Content-Type` is present and rejected by `cfg.content_types`.
///
/// Otherwise `Vary: Accept-Encoding` is ensured. The body is then still left
/// alone for `204`/`304` (which carry no content), when the client accepts none
/// of the enabled encodings, or when the collected body is empty or shorter than
/// `cfg.min_size`. If collecting the body fails, the status is kept and the body
/// is replaced with an empty one. If the encoder fails, the original bytes are
/// served with identity encoding.
async fn compress_middleware(req: Request, next: Next, cfg: Config) -> impl Responder {
    let accepted = req
        .headers()
        .get(ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_ascii_lowercase();
    // Must be computed before `req` is moved into the handler chain.
    let request_is_authenticated = cfg.protect_sensitive && request_carries_credentials(&req);
    let mut resp = next.run(req).await;
    let chosen = choose_encoding(&accepted, &cfg.enabled);

    let status = resp.status();
    let compressible_status = (status.is_success() && status != StatusCode::PARTIAL_CONTENT)
        || status == StatusCode::NOT_MODIFIED;
    if !compressible_status {
        return resp.into_response();
    }
    if resp.headers().contains_key(CONTENT_ENCODING) {
        return resp.into_response();
    }
    if cfg.protect_sensitive
        && (request_is_authenticated || resp.headers().contains_key(http::header::SET_COOKIE))
    {
        return resp.into_response();
    }
    if let Some(ct) = resp.headers().get(CONTENT_TYPE) {
        let ct = ct.to_str().unwrap_or("");
        if !cfg.content_types.matches(ct) {
            return resp.into_response();
        }
    }
    ensure_vary_accept_encoding(resp.headers_mut());

    // 204/304 never carry content; encoding their (empty) body would produce
    // encoder framing bytes where no body is allowed.
    if status == StatusCode::NO_CONTENT || status == StatusCode::NOT_MODIFIED {
        return resp.into_response();
    }
    // Nothing acceptable to the client: pass the body through instead of
    // buffering it for no benefit.
    let Some(enc) = chosen else {
        return resp.into_response();
    };

    let body_bytes = if let Ok(c) = resp.body_mut().collect().await {
        c.to_bytes()
    } else {
        tracing::warn!(
            "compression middleware: response body collect() failed; \
             returning original status with empty body (no compression)"
        );
        // The body is being replaced, so any declared length is now wrong.
        resp.headers_mut().remove(CONTENT_LENGTH);
        *resp.body_mut() = TakoBody::empty();
        return resp.into_response();
    };
    // An empty body is never worth encoding (even with `min_size == 0`).
    if body_bytes.is_empty() || body_bytes.len() < cfg.min_size {
        *resp.body_mut() = TakoBody::from(body_bytes);
        return resp.into_response();
    }

    let compressed = match enc {
        Encoding::Gzip => compress_gzip(&body_bytes, cfg.gzip_level),
        Encoding::Brotli => compress_brotli(&body_bytes, cfg.brotli_level),
        Encoding::Deflate => compress_deflate(&body_bytes, cfg.deflate_level),
        #[cfg(feature = "zstd")]
        Encoding::Zstd => compress_zstd(&body_bytes, cfg.zstd_level),
    };
    match compressed {
        Ok(buf) => {
            *resp.body_mut() = TakoBody::from(Bytes::from(buf));
            let headers = resp.headers_mut();
            headers.insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
            // Length and byte-range support describe the identity bytes, not
            // the encoded ones.
            headers.remove(CONTENT_LENGTH);
            headers.remove(http::header::ACCEPT_RANGES);
        }
        Err(err) => {
            tracing::warn!(
                encoding = enc.as_str(),
                error = %err,
                "compression failed; serving identity"
            );
            // Original bytes go back unchanged, so the original headers stay valid.
            *resp.body_mut() = TakoBody::from(body_bytes);
        }
    }
    resp.into_response()
}
/// Streaming compression middleware.
///
/// Runs the inner handler and, when every precondition holds, wraps the
/// response body in a streaming encoder instead of buffering it.
///
/// The response is returned untouched (no `Vary` added) when:
/// - the status is neither 2xx nor `304`, or is `206 Partial Content` (its
///   `Content-Range` describes offsets into the uncompressed bytes);
/// - the handler already set `Content-Encoding`;
/// - `protect_sensitive` is on and the request carried credentials or the
///   response sets a cookie;
/// - a `Content-Type` is present and rejected by `cfg.content_types`.
///
/// Otherwise `Vary: Accept-Encoding` is ensured. The body is then still left
/// alone for `204`/`304` (which carry no content), when a parseable
/// `Content-Length` is `0` or below `cfg.min_size`, or when the client accepts
/// none of the enabled encodings. Bodies without a `Content-Length` are always
/// eligible, since their size is unknown up front.
pub(crate) async fn compress_stream_middleware(
    req: Request,
    next: Next,
    cfg: Config,
) -> impl Responder {
    let accepted = req
        .headers()
        .get(ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_ascii_lowercase();
    // Must be computed before `req` is moved into the handler chain.
    let request_is_authenticated = cfg.protect_sensitive && request_carries_credentials(&req);
    let mut resp = next.run(req).await;
    let chosen = choose_encoding(&accepted, &cfg.enabled);

    let status = resp.status();
    let compressible_status = (status.is_success() && status != StatusCode::PARTIAL_CONTENT)
        || status == StatusCode::NOT_MODIFIED;
    if !compressible_status {
        return resp.into_response();
    }
    if resp.headers().contains_key(CONTENT_ENCODING) {
        return resp.into_response();
    }
    if cfg.protect_sensitive
        && (request_is_authenticated || resp.headers().contains_key(http::header::SET_COOKIE))
    {
        return resp.into_response();
    }
    if let Some(ct) = resp.headers().get(CONTENT_TYPE) {
        let ct = ct.to_str().unwrap_or("");
        if !cfg.content_types.matches(ct) {
            return resp.into_response();
        }
    }
    ensure_vary_accept_encoding(resp.headers_mut());

    // 204/304 never carry content; a streaming encoder wrapped around their
    // empty body could still emit framing bytes.
    if status == StatusCode::NO_CONTENT || status == StatusCode::NOT_MODIFIED {
        return resp.into_response();
    }
    // A declared-empty body is never encoded, even when `min_size == 0`.
    if let Some(len) = resp
        .headers()
        .get(CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse::<usize>().ok())
        && (len == 0 || len < cfg.min_size)
    {
        return resp.into_response();
    }
    let Some(enc) = chosen else {
        return resp.into_response();
    };

    let body = std::mem::replace(resp.body_mut(), TakoBody::empty());
    let new_body = match enc {
        Encoding::Gzip => stream_gzip(body, cfg.gzip_level),
        Encoding::Brotli => stream_brotli(body, cfg.brotli_level),
        Encoding::Deflate => stream_deflate(body, cfg.deflate_level),
        #[cfg(feature = "zstd")]
        Encoding::Zstd => stream_zstd(body, cfg.zstd_level),
    };
    *resp.body_mut() = new_body;
    let headers = resp.headers_mut();
    headers.insert(CONTENT_ENCODING, HeaderValue::from_static(enc.as_str()));
    // The encoded length is unknown up front, and byte ranges would refer to
    // the identity representation.
    headers.remove(CONTENT_LENGTH);
    headers.remove(http::header::ACCEPT_RANGES);
    resp.into_response()
}
/// Returns `true` when the request has an `Authorization`,
/// `Proxy-Authorization` or `Cookie` header.
fn request_carries_credentials(req: &Request) -> bool {
    let headers = req.headers();
    [
        http::header::AUTHORIZATION,
        http::header::PROXY_AUTHORIZATION,
        http::header::COOKIE,
    ]
    .iter()
    .any(|name| headers.contains_key(name))
}
fn ensure_vary_accept_encoding(headers: &mut http::HeaderMap) {
let already_present = headers.get_all(VARY).iter().any(|v| {
v.to_str().is_ok_and(|s| {
s.split(',')
.any(|tok| tok.trim().eq_ignore_ascii_case("Accept-Encoding"))
})
});
if !already_present {
headers.append(VARY, HeaderValue::from_static("Accept-Encoding"));
}
}
/// Picks the content-coding to apply.
///
/// `header` is the `Accept-Encoding` value (callers pass it lowercased), and
/// `enabled` lists the encodings the server is configured to use.
///
/// An encoding is a candidate when it is enabled and the client gives it a
/// positive quality: its own entry's q-value when listed, otherwise the `*`
/// wildcard's q-value, otherwise none (unlisted encodings are not acceptable).
/// Among candidates the highest client q-value wins. Ties are broken by server
/// preference: Brotli, gzip, deflate, then zstd when compiled in.
///
/// Returns `None` when no enabled encoding is acceptable, e.g. for an empty header.
fn choose_encoding(header: &str, enabled: &[Encoding]) -> Option<Encoding> {
    let parsed = parse_accept_encoding(header);
    let wildcard_q = parsed.iter().find(|(c, _)| c == "*").map(|(_, q)| *q);
    // An explicit entry (even q=0) overrides the wildcard.
    let client_q = |enc: Encoding| -> f32 {
        let name = enc.as_str();
        parsed
            .iter()
            .find(|(c, _)| c == name)
            .map(|(_, q)| *q)
            .or(wildcard_q)
            .unwrap_or(0.0)
    };

    let server_order = [Encoding::Brotli, Encoding::Gzip, Encoding::Deflate];
    let candidates = server_order.iter().copied();
    #[cfg(feature = "zstd")]
    let candidates = candidates.chain(std::iter::once(Encoding::Zstd));

    candidates
        .filter(|enc| enabled.contains(enc))
        .map(|enc| (enc, client_q(enc)))
        .filter(|&(_, q)| q > 0.0)
        .fold(None, |best: Option<(Encoding, f32)>, (enc, q)| match best {
            // Only a strictly higher q replaces the current pick, so ties keep
            // the earlier (server-preferred) encoding.
            Some((_, best_q)) if best_q >= q => best,
            _ => Some((enc, q)),
        })
        .map(|(enc, _)| enc)
}
/// Parses an `Accept-Encoding` value into `(coding, qvalue)` pairs, in header order.
///
/// - Codings are trimmed and lowercased; empty members are skipped.
/// - A missing `q` parameter means `1.0`; the parameter name may be `q` or `Q`.
/// - An entry whose q-value does not parse as a number is dropped entirely.
/// - Parsed q-values are clamped into `[0, 1]` (RFC 9110 §12.4.2), and NaN
///   becomes `0`, so callers can rank codings by q safely.
fn parse_accept_encoding(header: &str) -> Vec<(String, f32)> {
    header
        .split(',')
        .filter_map(|piece| {
            let piece = piece.trim();
            if piece.is_empty() {
                return None;
            }
            let mut parts = piece.split(';');
            let coding = parts.next()?.trim().to_ascii_lowercase();
            if coding.is_empty() {
                return None;
            }
            let mut q: f32 = 1.0;
            for param in parts {
                let param = param.trim();
                let qv = param
                    .strip_prefix("q=")
                    .or_else(|| param.strip_prefix("Q="));
                if let Some(qv) = qv {
                    // Malformed q-value: drop this entry only.
                    q = qv.parse().ok()?;
                }
            }
            // `clamp` passes NaN through, so map it to "not acceptable" explicitly.
            let q = if q.is_nan() { 0.0 } else { q.clamp(0.0, 1.0) };
            Some((coding, q))
        })
        .collect()
}
/// gzip-compresses `data` entirely in memory at flate2 level `lvl`.
///
/// # Errors
/// Propagates any I/O error from writing into or finishing the encoder.
fn compress_gzip(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let level = GzLevel::new(lvl);
    let mut encoder = GzEncoder::new(Vec::new(), level);
    encoder.write_all(data).and_then(|()| encoder.finish())
}
/// Brotli-compresses `data` in one shot at quality `lvl`.
///
/// Uses a 4096-byte internal buffer and a window-size parameter (`lgwin`) of 22.
///
/// # Errors
/// Any failure while reading from the compressor is reported as a generic
/// "Failed to compress data" error.
fn compress_brotli(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let mut compressor = brotli::CompressorReader::new(data, 4096, lvl, 22);
    let mut compressed = Vec::new();
    match compressor.read_to_end(&mut compressed) {
        Ok(_) => Ok(compressed),
        Err(_) => Err(std::io::Error::other("Failed to compress data")),
    }
}
/// deflate-compresses `data` entirely in memory at flate2 level `lvl`.
///
/// # Errors
/// Propagates any I/O error from writing into or finishing the encoder.
fn compress_deflate(data: &[u8], lvl: u32) -> std::io::Result<Vec<u8>> {
    let mut encoder = DeflateEncoder::new(Vec::new(), flate2::Compression::new(lvl));
    encoder.write_all(data)?;
    let compressed = encoder.finish()?;
    Ok(compressed)
}
/// zstd-compresses `data` in one shot at level `lvl`.
///
/// # Errors
/// Propagates any I/O error reported by the zstd encoder.
#[cfg(feature = "zstd")]
#[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
fn compress_zstd(data: &[u8], lvl: i32) -> std::io::Result<Vec<u8>> {
    let compressed = zstd_encode(data, lvl)?;
    Ok(compressed)
}