use super::HttpTask;
use bytes::Bytes;
use log::warn;
use pingora_error::{ErrorType, Result};
use pingora_http::{RequestHeader, ResponseHeader};
use std::time::Duration;
use strum::EnumCount;
use strum_macros::EnumCount as EnumCountMacro;
mod brotli;
mod gzip;
mod zstd;
/// The error type reported when a compression or decompression operation fails.
pub const COMPRESSION_ERROR: ErrorType = ErrorType::new("CompressionError");
/// The common interface for compressors and decompressors: both simply
/// transform bytes into bytes.
pub trait Encode {
    /// Encode `input`, returning whatever output bytes are ready.
    ///
    /// `end` signals the end of the entire stream so the implementation can
    /// flush buffered data and finalize its output.
    fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes>;
    /// Return encoder statistics as `(algorithm name, bytes in, bytes out, time spent)`.
    // NOTE(review): tuple field meaning inferred from how `get_info()` exposes it —
    // confirm against the gzip/brotli/zstd implementations.
    fn stat(&self) -> (&'static str, usize, usize, Duration);
}
/// The response (de)compression filter context.
///
/// Supports gzip, brotli and zstd compression and gzip and brotli
/// decompression (see `Algorithm::compressor` / `Algorithm::decompressor`).
pub struct ResponseCompressionCtx(CtxInner);
/// Internal state machine for `ResponseCompressionCtx`.
///
/// The context starts in `HeaderPhase`, where per-algorithm settings can still
/// be adjusted. Once the response header has been processed it transitions to
/// `BodyPhase`, which holds the selected encoder (or `None` for pass-through).
enum CtxInner {
    HeaderPhase {
        // Algorithms the client accepts, parsed from the Accept-Encoding request header.
        accept_encoding: Vec<Algorithm>,
        // Compression level per algorithm, indexed by `Algorithm::index()`; 0 disables.
        encoding_levels: [u32; Algorithm::COUNT],
        // Whether decompression is enabled, per algorithm.
        decompress_enable: [bool; Algorithm::COUNT],
        // Whether to leave the response ETag untouched when transforming, per algorithm.
        preserve_etag: [bool; Algorithm::COUNT],
    },
    // The encoder/decoder chosen for the body; `None` means pass through unchanged.
    BodyPhase(Option<Box<dyn Encode + Send + Sync>>),
}
impl ResponseCompressionCtx {
    /// Create a new context with the same settings applied to every algorithm.
    ///
    /// A `compression_level` of 0 disables compression. `decompress_enable`
    /// controls whether already-encoded upstream bodies may be decompressed.
    /// `preserve_etag` controls whether the response ETag is left untouched
    /// when the body is transformed.
    pub fn new(compression_level: u32, decompress_enable: bool, preserve_etag: bool) -> Self {
        Self(CtxInner::HeaderPhase {
            accept_encoding: Vec::new(),
            encoding_levels: [compression_level; Algorithm::COUNT],
            decompress_enable: [decompress_enable; Algorithm::COUNT],
            preserve_etag: [preserve_etag; Algorithm::COUNT],
        })
    }
    /// Whether this filter can do anything at all.
    ///
    /// Header phase: true if any algorithm has a non-zero compression level or
    /// decompression enabled. Body phase: true if an encoder was selected.
    pub fn is_enabled(&self) -> bool {
        match &self.0 {
            CtxInner::HeaderPhase {
                decompress_enable,
                encoding_levels: levels,
                ..
            } => levels.iter().any(|l| *l != 0) || decompress_enable.iter().any(|d| *d),
            CtxInner::BodyPhase(c) => c.is_some(),
        }
    }
    /// Return the active encoder's statistics (see `Encode::stat`), or `None`
    /// while still in the header phase or when no encoder was selected.
    pub fn get_info(&self) -> Option<(&'static str, usize, usize, Duration)> {
        match &self.0 {
            CtxInner::HeaderPhase { .. } => None,
            CtxInner::BodyPhase(c) => c.as_ref().map(|c| c.stat()),
        }
    }
    /// Set the compression level for every algorithm at once.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_level(&mut self, new_level: u32) {
        match &mut self.0 {
            CtxInner::HeaderPhase {
                encoding_levels: levels,
                ..
            } => {
                *levels = [new_level; Algorithm::COUNT];
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Set the compression level for a single algorithm.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_algorithm_level(&mut self, algorithm: Algorithm, new_level: u32) {
        match &mut self.0 {
            CtxInner::HeaderPhase {
                encoding_levels: levels,
                ..
            } => {
                levels[algorithm.index()] = new_level;
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Enable or disable decompression for every algorithm at once.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_decompression(&mut self, enabled: bool) {
        match &mut self.0 {
            CtxInner::HeaderPhase {
                decompress_enable, ..
            } => {
                *decompress_enable = [enabled; Algorithm::COUNT];
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Enable or disable decompression for a single algorithm.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_algorithm_decompression(&mut self, algorithm: Algorithm, enabled: bool) {
        match &mut self.0 {
            CtxInner::HeaderPhase {
                decompress_enable, ..
            } => {
                decompress_enable[algorithm.index()] = enabled;
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Set ETag preservation for every algorithm at once.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_preserve_etag(&mut self, enabled: bool) {
        match &mut self.0 {
            CtxInner::HeaderPhase { preserve_etag, .. } => {
                *preserve_etag = [enabled; Algorithm::COUNT];
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Set ETag preservation for a single algorithm.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn adjust_algorithm_preserve_etag(&mut self, algorithm: Algorithm, enabled: bool) {
        match &mut self.0 {
            CtxInner::HeaderPhase { preserve_etag, .. } => {
                preserve_etag[algorithm.index()] = enabled;
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Parse the request's Accept-Encoding header into the context's list of
    /// acceptable algorithms. No-op when the filter is disabled.
    ///
    /// # Panics
    /// Panics if called after the response header has been processed (body phase).
    pub fn request_filter(&mut self, req: &RequestHeader) {
        if !self.is_enabled() {
            return;
        }
        match &mut self.0 {
            CtxInner::HeaderPhase {
                accept_encoding, ..
            } => parse_accept_encoding(
                req.headers.get(http::header::ACCEPT_ENCODING),
                accept_encoding,
            ),
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Inspect the response header, decide whether the body should be
    /// compressed, decompressed or left alone, adjust the header accordingly,
    /// and transition to the body phase.
    ///
    /// `end` is true when the header is the end of the response (no body).
    ///
    /// # Panics
    /// Panics if called again after the transition to the body phase.
    pub fn response_header_filter(&mut self, resp: &mut ResponseHeader, end: bool) {
        if !self.is_enabled() {
            return;
        }
        match &self.0 {
            CtxInner::HeaderPhase {
                decompress_enable,
                preserve_etag,
                accept_encoding,
                encoding_levels: levels,
            } => {
                // 1xx responses carry no body. On 101 Switching Protocols the
                // HTTP exchange ends, so settle into a no-op body phase; other
                // informational responses are skipped (the real header follows).
                if resp.status.is_informational() {
                    if resp.status == http::status::StatusCode::SWITCHING_PROTOCOLS {
                        self.0 = CtxInner::BodyPhase(None);
                    }
                    return;
                }
                // Header-only response: nothing to (de)compress.
                if end {
                    self.0 = CtxInner::BodyPhase(None);
                    return;
                }
                // If the response could differ based on the client's
                // Accept-Encoding, caches must be told via the Vary header.
                if depends_on_accept_encoding(
                    resp,
                    levels.iter().any(|level| *level != 0),
                    decompress_enable,
                ) {
                    add_vary_header(resp, &http::header::ACCEPT_ENCODING);
                }
                let action = decide_action(resp, accept_encoding);
                let (encoder, preserve_etag) = match action {
                    Action::Noop => (None, false),
                    Action::Compress(algorithm) => {
                        let idx = algorithm.index();
                        (algorithm.compressor(levels[idx]), preserve_etag[idx])
                    }
                    Action::Decompress(algorithm) => {
                        let idx = algorithm.index();
                        (
                            algorithm.decompressor(decompress_enable[idx]),
                            preserve_etag[idx],
                        )
                    }
                };
                // Only rewrite headers if an encoder was actually constructed:
                // the chosen algorithm may still be disabled (level 0) or have
                // no (de)compressor implementation.
                if encoder.is_some() {
                    adjust_response_header(resp, &action, preserve_etag);
                }
                self.0 = CtxInner::BodyPhase(encoder);
            }
            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
        }
    }
    /// Feed a chunk of the response body through the selected encoder, if any.
    ///
    /// Returns the transformed bytes, or `None` when no encoder is active.
    /// `end` marks the final chunk so the encoder can flush. On encoder
    /// failure the error is logged and compression is disabled for the rest of
    /// the response instead of failing the request.
    ///
    /// # Panics
    /// Panics if called before `response_header_filter` (header phase).
    pub fn response_body_filter(&mut self, data: Option<&Bytes>, end: bool) -> Option<Bytes> {
        match &mut self.0 {
            CtxInner::HeaderPhase { .. } => panic!("Wrong phase: HeaderPhase"),
            CtxInner::BodyPhase(compressor) => {
                let result = compressor
                    .as_mut()
                    .map(|c| {
                        // `None` data (e.g. from HttpTask::Done) is fed as an
                        // empty slice so the encoder can still flush on `end`.
                        let data = if let Some(b) = data { b.as_ref() } else { &[] };
                        c.encode(data, end)
                    })
                    .transpose();
                result.unwrap_or_else(|e| {
                    warn!("Failed to compress, compression disabled, {}", e);
                    // Drop the encoder: the remainder of the body passes through.
                    self.0 = CtxInner::BodyPhase(None);
                    None
                })
            }
        }
    }
    /// Convenience wrapper: apply the appropriate filter for the given task.
    pub fn response_filter(&mut self, t: &mut HttpTask) {
        if !self.is_enabled() {
            return;
        }
        match t {
            HttpTask::Header(resp, end) => self.response_header_filter(resp, *end),
            HttpTask::Body(data, end) => {
                let compressed = self.response_body_filter(data.as_ref(), *end);
                if compressed.is_some() {
                    *t = HttpTask::Body(compressed, *end);
                }
            }
            HttpTask::Done => {
                // End of stream: flush the encoder. Any flushed bytes replace
                // the Done marker with a final Body task whose end flag is set.
                let compressed = self.response_body_filter(None, true);
                if compressed.is_some() {
                    *t = HttpTask::Body(compressed, true);
                }
            }
            _ => { /* other task variants pass through untouched */ }
        }
    }
}
/// Supported content encodings, plus catch-all variants.
///
/// NOTE: declaration order is significant — `index()` casts the variant to
/// `usize` to index the per-algorithm settings arrays, so do not reorder.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumCountMacro)]
pub enum Algorithm {
    /// The wildcard encoding `*` (also produced from an empty token).
    Any,
    Gzip,
    Brotli,
    Zstd,
    /// Any encoding this module does not recognize.
    Other,
}
impl Algorithm {
    /// The canonical token for this algorithm as used in HTTP encoding headers.
    pub fn as_str(&self) -> &'static str {
        match self {
            Algorithm::Any => "*",
            Algorithm::Gzip => "gzip",
            Algorithm::Brotli => "br",
            Algorithm::Zstd => "zstd",
            Algorithm::Other => "other",
        }
    }
    /// Build a compressor for this algorithm at the given level.
    ///
    /// Returns `None` when `level` is 0 (compression disabled) or when the
    /// variant has no compressor (`Any`/`Other`).
    pub fn compressor(&self, level: u32) -> Option<Box<dyn Encode + Send + Sync>> {
        if level == 0 {
            return None;
        }
        match self {
            Self::Gzip => Some(Box::new(gzip::Compressor::new(level))),
            Self::Brotli => Some(Box::new(brotli::Compressor::new(level))),
            Self::Zstd => Some(Box::new(zstd::Compressor::new(level))),
            // `Any` and `Other` have nothing concrete to compress with.
            _ => None,
        }
    }
    /// Build a decompressor for this algorithm.
    ///
    /// Returns `None` when `enabled` is false or when no decompressor exists
    /// for the variant (only gzip and brotli decompression are implemented here).
    pub fn decompressor(&self, enabled: bool) -> Option<Box<dyn Encode + Send + Sync>> {
        if !enabled {
            return None;
        }
        match self {
            Self::Gzip => Some(Box::new(gzip::Decompressor::new())),
            Self::Brotli => Some(Box::new(brotli::Decompressor::new())),
            _ => None,
        }
    }
    /// Index into the per-algorithm settings arrays; relies on declaration order.
    pub fn index(&self) -> usize {
        *self as usize
    }
}
impl From<&str> for Algorithm {
    /// Map an HTTP coding token to an [`Algorithm`], case-insensitively.
    ///
    /// The empty token maps to [`Algorithm::Any`]; anything unrecognized maps
    /// to [`Algorithm::Other`].
    fn from(s: &str) -> Self {
        use unicase::UniCase;
        let coding = UniCase::new(s);
        // Case-insensitive comparison against a known token.
        let is = |token: &str| coding == UniCase::ascii(token);
        if is("gzip") {
            Algorithm::Gzip
        } else if is("br") {
            Algorithm::Brotli
        } else if is("zstd") {
            Algorithm::Zstd
        } else if s.is_empty() {
            Algorithm::Any
        } else {
            Algorithm::Other
        }
    }
}
/// What to do with the response body with regard to (de)compression.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Action {
    /// Pass the body through unchanged.
    Noop,
    /// Compress the body with the given algorithm.
    Compress(Algorithm),
    /// Decompress a body already encoded with the given algorithm.
    Decompress(Algorithm),
}
// Parse the Accept-Encoding request header value into the list of algorithms
// the client accepts, preserving header order. Unknown encodings are dropped.
// https://www.rfc-editor.org/rfc/rfc9110#name-accept-encoding
// NOTE(review): q-values/weights are not interpreted — TODO confirm that is intended.
fn parse_accept_encoding(accept_encoding: Option<&http::HeaderValue>, list: &mut Vec<Algorithm>) {
    if let Some(ac) = accept_encoding {
        // Fast path for the single most common header value.
        if ac.as_bytes() == b"gzip" {
            list.push(Algorithm::Gzip);
            return;
        }
        // Otherwise parse the value as a structured-field list.
        match sfv::Parser::parse_list(ac.as_bytes()) {
            Ok(parsed) => {
                for item in parsed {
                    if let sfv::ListEntry::Item(i) = item {
                        if let Some(s) = i.bare_item.as_token() {
                            // Convert once and reuse the result (the original
                            // code called Algorithm::from(s) a second time).
                            let algorithm = Algorithm::from(s);
                            if algorithm != Algorithm::Other {
                                list.push(algorithm);
                            }
                        }
                    }
                }
            }
            Err(e) => {
                warn!("Failed to parse accept-encoding {ac:?}, {e}")
            }
        }
    } else {
        // No Accept-Encoding header: any encoding is acceptable, so the list
        // is deliberately left empty.
    }
}
#[test]
fn test_accept_encoding_req_header() {
    // No Accept-Encoding header: the list stays empty (any encoding acceptable).
    let mut header = RequestHeader::build("GET", b"/", None).unwrap();
    let mut ac_list = Vec::new();
    parse_accept_encoding(
        header.headers.get(http::header::ACCEPT_ENCODING),
        &mut ac_list,
    );
    assert!(ac_list.is_empty());
    // Fast path: the exact value "gzip".
    let mut ac_list = Vec::new();
    header.insert_header("accept-encoding", "gzip").unwrap();
    parse_accept_encoding(
        header.headers.get(http::header::ACCEPT_ENCODING),
        &mut ac_list,
    );
    assert_eq!(ac_list[0], Algorithm::Gzip);
    // Unknown tokens are skipped; recognized ones keep their header order.
    let mut ac_list = Vec::new();
    header
        .insert_header("accept-encoding", "what, br, gzip")
        .unwrap();
    parse_accept_encoding(
        header.headers.get(http::header::ACCEPT_ENCODING),
        &mut ac_list,
    );
    assert_eq!(ac_list[0], Algorithm::Brotli);
    assert_eq!(ac_list[1], Algorithm::Gzip);
}
/// Whether the response could differ depending on the request's
/// Accept-Encoding header — in which case `Vary: Accept-Encoding` should be
/// added so caches key on it.
fn depends_on_accept_encoding(
    resp: &ResponseHeader,
    compress_enabled: bool,
    decompress_enabled: &[bool],
) -> bool {
    use http::header::CONTENT_ENCODING;
    // We might decompress an already-encoded body for some clients...
    let may_decompress =
        decompress_enabled.contains(&true) && resp.headers.get(CONTENT_ENCODING).is_some();
    // ...or compress a compressible plain body for others.
    let may_compress = compress_enabled && compressible(resp);
    may_decompress || may_compress
}
#[test]
fn test_decide_on_accept_encoding() {
    let mut resp = ResponseHeader::build(200, None).unwrap();
    resp.insert_header("content-length", "50").unwrap();
    resp.insert_header("content-type", "text/html").unwrap();
    resp.insert_header("content-encoding", "gzip").unwrap();
    // Encoded body: depends on Accept-Encoding only if decompression is possible.
    assert!(depends_on_accept_encoding(&resp, false, &[true]));
    assert!(!depends_on_accept_encoding(&resp, false, &[false]));
    resp.remove_header("content-encoding");
    // Plain body: depends on Accept-Encoding only if compression is enabled.
    assert!(!depends_on_accept_encoding(&resp, false, &[true]));
    assert!(depends_on_accept_encoding(&resp, true, &[false]));
    assert!(!depends_on_accept_encoding(&resp, false, &[false]));
    // A non-compressible content type never depends on Accept-Encoding.
    resp.insert_header("content-type", "text/html+zip").unwrap();
    assert!(!depends_on_accept_encoding(&resp, true, &[false]));
}
/// Decide what to do with the response body, based on its current
/// Content-Encoding and the encodings the client accepts.
fn decide_action(resp: &ResponseHeader, accept_encoding: &[Algorithm]) -> Action {
    use http::header::CONTENT_ENCODING;
    // The algorithm the body is currently encoded with, if any. A header value
    // that is not valid UTF-8 counts as an unknown encoding.
    let content_encoding = resp.headers.get(CONTENT_ENCODING).map(|ce| {
        std::str::from_utf8(ce.as_bytes())
            .map(Algorithm::from)
            .unwrap_or(Algorithm::Other)
    });
    match content_encoding {
        // Already encoded in something the client accepts: leave it alone.
        Some(ce) if accept_encoding.contains(&ce) => Action::Noop,
        // Encoded in something the client does not accept: decompress it.
        Some(ce) => Action::Decompress(ce),
        // Plain body: compress with the client's first preference, unless the
        // body is not worth compressing or the preference is the wildcard.
        None => match accept_encoding.first() {
            Some(&first) if first != Algorithm::Any && compressible(resp) => {
                Action::Compress(first)
            }
            _ => Action::Noop,
        },
    }
}
#[test]
fn test_decide_action() {
    use Action::*;
    use Algorithm::*;
    // Nothing accepted, nothing encoded: no-op.
    let header = ResponseHeader::build(200, None).unwrap();
    assert_eq!(decide_action(&header, &[]), Noop);
    // Encoding already matches an accepted algorithm: no-op.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-type", "text/html").unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Noop);
    // Content-Encoding tokens are matched case-insensitively.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-encoding", "GzIp").unwrap();
    header.insert_header("content-type", "text/html").unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Noop);
    // Compressible plain body at/above the size threshold: compress.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header.insert_header("content-type", "text/html").unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Compress(Gzip));
    // Body below the minimum size: no-op.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "19").unwrap();
    header.insert_header("content-type", "text/html").unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Noop);
    // Content type containing "zip" is treated as already compressed: no-op.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header
        .insert_header("content-type", "text/html+zip")
        .unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Noop);
    // Non-compressible content type: no-op.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header.insert_header("content-type", "image/jpg").unwrap();
    assert_eq!(decide_action(&header, &[Gzip]), Noop);
    // Encoded body the client did not ask for: decompress.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    assert_eq!(decide_action(&header, &[]), Decompress(Gzip));
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    assert_eq!(decide_action(&header, &[Brotli]), Decompress(Gzip));
    // The encoding matching ANY accepted algorithm (not just the first) is a no-op.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    assert_eq!(decide_action(&header, &[Brotli, Gzip]), Noop);
}
use once_cell::sync::Lazy;
use regex::Regex;
// MIME types considered compressible: textual types, application/*, fonts,
// a few image formats that are stored uncompressed, and binary/octet-stream.
// NOTE(review): the original alternative `nd\.microsoft\.icon` matches no
// registered MIME type; the IANA-registered type for .ico files is
// `image/vnd.microsoft.icon`, so the missing `v` was restored.
static MIME_CHECK: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"^(?:text/|application/|font/|image/(?:x-icon|svg\+xml|vnd\.microsoft\.icon)|binary/octet-stream)")
        .unwrap()
});
/// Whether a response body looks worth compressing, judged from its
/// Content-Length and Content-Type headers.
fn compressible(resp: &ResponseHeader) -> bool {
    // Tiny bodies are not worth the compression overhead.
    const MIN_COMPRESS_LEN: usize = 20;
    let too_small = resp
        .headers
        .get(http::header::CONTENT_LENGTH)
        .and_then(|cl| std::str::from_utf8(cl.as_bytes()).ok())
        .and_then(|cl| cl.parse::<usize>().ok())
        .map_or(false, |len| len < MIN_COMPRESS_LEN);
    if too_small {
        return false;
    }
    // Only known-compressible MIME types qualify; anything mentioning "zip" is
    // assumed already compressed. A missing or non-UTF-8 Content-Type means no.
    match resp
        .headers
        .get(http::header::CONTENT_TYPE)
        .and_then(|ct| std::str::from_utf8(ct.as_bytes()).ok())
    {
        Some(ct) if ct.contains("zip") => false,
        Some(ct) => MIME_CHECK.find(ct).is_some(),
        None => false,
    }
}
/// Append `value` to the response's `Vary` header unless it is already covered.
///
/// "Covered" means some existing `Vary` value already lists the field name
/// (comma-separated, case-insensitive, surrounding whitespace ignored) or is
/// the wildcard `*`.
fn add_vary_header(resp: &mut ResponseHeader, value: &http::header::HeaderName) {
    use http::header::{HeaderValue, VARY};
    let already_present = resp.headers.get_all(VARY).iter().any(|existing| {
        existing
            .as_bytes()
            .split(|b| *b == b',')
            .map(|mut v| {
                // Manually trim ASCII whitespace from both ends of the byte
                // slice so values like "A, B" compare correctly.
                // Trim leading whitespace...
                while let [first, rest @ ..] = v {
                    if first.is_ascii_whitespace() {
                        v = rest;
                    } else {
                        break;
                    }
                }
                // ...then trailing whitespace.
                while let [rest @ .., last] = v {
                    if last.is_ascii_whitespace() {
                        v = rest;
                    } else {
                        break;
                    }
                }
                v
            })
            .any(|v| v == b"*" || v.eq_ignore_ascii_case(value.as_ref()))
    });
    if !already_present {
        resp.append_header(&VARY, HeaderValue::from_name(value.clone()))
            .unwrap();
    }
}
#[test]
fn test_add_vary_header() {
    // No Vary header yet: the field name is appended.
    let mut header = ResponseHeader::build(200, None).unwrap();
    add_vary_header(&mut header, &http::header::ACCEPT_ENCODING);
    assert_eq!(
        header
            .headers
            .get_all("Vary")
            .into_iter()
            .collect::<Vec<_>>(),
        vec!["accept-encoding"]
    );
    // An unrelated Vary value: a second header value is appended.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("Vary", "Accept-Language").unwrap();
    add_vary_header(&mut header, &http::header::ACCEPT_ENCODING);
    assert_eq!(
        header
            .headers
            .get_all("Vary")
            .into_iter()
            .collect::<Vec<_>>(),
        vec!["Accept-Language", "accept-encoding"]
    );
    // Already listed (case-insensitive, whitespace-tolerant): unchanged.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header
        .insert_header("Vary", "Accept-Language, Accept-Encoding")
        .unwrap();
    add_vary_header(&mut header, &http::header::ACCEPT_ENCODING);
    assert_eq!(
        header
            .headers
            .get_all("Vary")
            .into_iter()
            .collect::<Vec<_>>(),
        vec!["Accept-Language, Accept-Encoding"]
    );
    // Wildcard covers everything: unchanged.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("Vary", "*").unwrap();
    add_vary_header(&mut header, &http::header::ACCEPT_ENCODING);
    assert_eq!(
        header
            .headers
            .get_all("Vary")
            .into_iter()
            .collect::<Vec<_>>(),
        vec!["*"]
    );
}
/// Rewrite response headers to reflect the chosen (de)compression action.
///
/// For both compress and decompress the output length is unknown up front, so
/// `Content-Length` and `Accept-Ranges` are dropped and
/// `Transfer-Encoding: chunked` is set. Unless `preserve_etag` is true, a
/// strong ETag is weakened with a `W/` prefix (the bytes change but the
/// semantics do not); anything that doesn't look like a valid ETag is removed.
fn adjust_response_header(resp: &mut ResponseHeader, action: &Action, preserve_etag: bool) {
    use http::header::{
        HeaderValue, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_LENGTH, ETAG, TRANSFER_ENCODING,
    };
    // Switch the response to a streamed body of unknown length.
    fn set_stream_headers(resp: &mut ResponseHeader) {
        resp.remove_header(&CONTENT_LENGTH);
        resp.remove_header(&ACCEPT_RANGES);
        resp.insert_header(&TRANSFER_ENCODING, HeaderValue::from_static("chunked"))
            .unwrap();
    }
    // Weaken a strong (quote-prefixed) ETag to `W/"..."`; leave already-weak
    // ETags untouched; drop malformed ones.
    fn weaken_or_clear_etag(resp: &mut ResponseHeader) {
        if let Some(etag) = resp.headers.get(&ETAG) {
            let etag_bytes = etag.as_bytes();
            if etag_bytes.starts_with(b"W/") {
                // already weak: nothing to do
            } else if etag_bytes.starts_with(b"\"") {
                let weakened_etag = HeaderValue::from_bytes(&[b"W/", etag_bytes].concat())
                    .expect("valid header value prefixed with \"W/\" should remain valid");
                resp.insert_header(&ETAG, weakened_etag)
                    .expect("can insert weakened etag when etag was already valid");
            } else {
                resp.remove_header(&ETAG);
            }
        }
    }
    match action {
        Action::Noop => { /* body untouched: headers stay as-is */ }
        Action::Decompress(_) => {
            resp.remove_header(&CONTENT_ENCODING);
            set_stream_headers(resp);
            if !preserve_etag {
                weaken_or_clear_etag(resp);
            }
        }
        Action::Compress(a) => {
            resp.insert_header(&CONTENT_ENCODING, HeaderValue::from_static(a.as_str()))
                .unwrap();
            set_stream_headers(resp);
            if !preserve_etag {
                weaken_or_clear_etag(resp);
            }
        }
    }
}
#[test]
fn test_adjust_response_header() {
    use Action::*;
    use Algorithm::*;
    // Noop leaves every header untouched.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    header.insert_header("accept-ranges", "bytes").unwrap();
    header.insert_header("etag", "\"abc123\"").unwrap();
    adjust_response_header(&mut header, &Noop, false);
    assert_eq!(
        header.headers.get("content-encoding").unwrap().as_bytes(),
        b"gzip"
    );
    assert_eq!(
        header.headers.get("content-length").unwrap().as_bytes(),
        b"20"
    );
    assert_eq!(
        header.headers.get("etag").unwrap().as_bytes(),
        b"\"abc123\""
    );
    assert!(header.headers.get("transfer-encoding").is_none());
    // Decompress: drops Content-Encoding, switches to chunked streaming,
    // and weakens the strong ETag.
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header.insert_header("content-encoding", "gzip").unwrap();
    header.insert_header("accept-ranges", "bytes").unwrap();
    header.insert_header("etag", "\"abc123\"").unwrap();
    adjust_response_header(&mut header, &Decompress(Gzip), false);
    assert!(header.headers.get("content-encoding").is_none());
    assert!(header.headers.get("content-length").is_none());
    assert_eq!(
        header.headers.get("transfer-encoding").unwrap().as_bytes(),
        b"chunked"
    );
    assert!(header.headers.get("accept-ranges").is_none());
    assert_eq!(
        header.headers.get("etag").unwrap().as_bytes(),
        b"W/\"abc123\""
    );
    // With preserve_etag the ETag is left untouched.
    header.insert_header("etag", "\"abc123\"").unwrap();
    adjust_response_header(&mut header, &Decompress(Gzip), true);
    assert_eq!(
        header.headers.get("etag").unwrap().as_bytes(),
        b"\"abc123\""
    );
    // Compress: sets Content-Encoding, switches to chunked streaming, and
    // removes an ETag that is neither weak nor quoted (malformed).
    let mut header = ResponseHeader::build(200, None).unwrap();
    header.insert_header("content-length", "20").unwrap();
    header.insert_header("accept-ranges", "bytes").unwrap();
    header.insert_header("etag", "abc123").unwrap();
    adjust_response_header(&mut header, &Compress(Gzip), false);
    assert_eq!(
        header.headers.get("content-encoding").unwrap().as_bytes(),
        b"gzip"
    );
    assert!(header.headers.get("content-length").is_none());
    assert!(header.headers.get("accept-ranges").is_none());
    assert_eq!(
        header.headers.get("transfer-encoding").unwrap().as_bytes(),
        b"chunked"
    );
    assert!(header.headers.get("etag").is_none());
    // With preserve_etag even the malformed ETag survives.
    header.insert_header("etag", "abc123").unwrap();
    adjust_response_header(&mut header, &Compress(Gzip), true);
    assert_eq!(header.headers.get("etag").unwrap().as_bytes(), b"abc123");
}