use http::{header, Extensions, HeaderMap, StatusCode, Version};
use http_body::Body;
use std::{fmt, sync::Arc};
/// Decides whether a given response should be compressed.
///
/// The response is only borrowed, so a predicate may inspect its status, headers,
/// extensions, or the body's `size_hint`, but it cannot consume the body.
pub trait Predicate: Clone {
    /// Returns `true` if `response` should be compressed.
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body;

    /// Combines `self` with `other`, producing a predicate that only allows compression
    /// when both of them do (`self` is consulted first).
    fn and<Other>(self, other: Other) -> And<Self, Other>
    where
        Self: Sized,
        Other: Predicate,
    {
        And { lhs: self, rhs: other }
    }
}
/// Any cloneable closure over the response head can be used as a predicate.
///
/// The closure receives the status, HTTP version, headers, and extensions of the response,
/// in that order.
impl<F> Predicate for F
where
    F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        (self)(
            response.status(),
            response.version(),
            response.headers(),
            response.extensions(),
        )
    }
}
/// An optional predicate: `Some(p)` defers to `p`, while `None` always allows compression.
impl<T> Predicate for Option<T>
where
    T: Predicate,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        match self {
            Some(inner) => inner.should_compress(response),
            None => true,
        }
    }
}
/// Two predicates combined with a logical AND.
///
/// Built by [`Predicate::and`]; compression is allowed only when both `lhs` and `rhs` allow it.
#[derive(Clone, Copy, Debug, Default)]
pub struct And<Lhs, Rhs> {
    lhs: Lhs,
    rhs: Rhs,
}
impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
where
    Lhs: Predicate,
    Rhs: Predicate,
{
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        // Short-circuit: the right-hand predicate is only consulted when the left one agrees.
        if !self.lhs.should_compress(response) {
            return false;
        }
        self.rhs.should_compress(response)
    }
}
/// The predicate used when no custom predicate is configured.
///
/// A response is compressed unless any of the following holds:
/// - its size (the body's exact size hint, falling back to `Content-Length`) is below the
///   default minimum of [`SizeAbove`] (32 bytes);
/// - its `Content-Type` starts with `application/grpc`;
/// - its `Content-Type` is rejected by [`NotForContentType::IMAGES`];
/// - its `Content-Type` starts with `text/event-stream` (Server-Sent Events).
#[derive(Clone, Debug)]
pub struct DefaultPredicate(
    And<And<And<SizeAbove, NotForContentType>, NotForContentType>, NotForContentType>,
);

impl DefaultPredicate {
    /// Creates the default predicate described on [`DefaultPredicate`].
    pub fn new() -> Self {
        let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
            .and(NotForContentType::GRPC)
            .and(NotForContentType::IMAGES)
            // Compression encoders typically buffer output until enough input arrives, which
            // would hold back individual events and defeat the purpose of an SSE stream.
            .and(NotForContentType::const_new("text/event-stream"));
        Self(inner)
    }
}
impl Default for DefaultPredicate {
fn default() -> Self {
Self::new()
}
}
impl Predicate for DefaultPredicate {
    /// Delegates to the wrapped combination of predicates built in [`DefaultPredicate::new`].
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        let Self(combined) = self;
        combined.should_compress(response)
    }
}
/// Predicate that allows compression only for responses of at least a minimum size in bytes.
///
/// When the size cannot be determined, compression is allowed.
#[derive(Clone, Copy, Debug)]
pub struct SizeAbove(u16);
impl SizeAbove {
const DEFAULT_MIN_SIZE: u16 = 32;
pub const fn new(min_size_bytes: u16) -> Self {
Self(min_size_bytes)
}
}
impl Default for SizeAbove {
fn default() -> Self {
Self(Self::DEFAULT_MIN_SIZE)
}
}
impl Predicate for SizeAbove {
    /// Prefers the body's exact size hint; otherwise parses the `Content-Length` header.
    /// If neither yields a size, compression is allowed.
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        let threshold = u64::from(self.0);

        // Only evaluated when the body cannot report an exact size.
        let declared_length = || -> Option<u64> {
            let raw = response.headers().get(header::CONTENT_LENGTH)?;
            raw.to_str().ok()?.parse().ok()
        };

        match response.body().size_hint().exact().or_else(declared_length) {
            Some(size) => size >= threshold,
            None => true,
        }
    }
}
/// Predicate that disallows compression when the response's `Content-Type` starts with a
/// given prefix.
///
/// Matching ignores ASCII case, because media types are case-insensitive (RFC 9110 §8.3.1).
/// A response without a readable `Content-Type` header is treated as having an empty one.
#[derive(Clone, Debug)]
pub struct NotForContentType {
    // Prefix of content types that must not be compressed.
    content_type: Str,
    // Media type (parameters excluded) that is still compressed even though it matches
    // `content_type`.
    exception: Option<Str>,
}

impl NotForContentType {
    /// Predicate that won't compress gRPC responses.
    pub const GRPC: Self = Self::const_new("application/grpc");

    /// Predicate that won't compress images, except `image/svg+xml`.
    ///
    /// SVG is XML text rather than an already-compressed raster format, so it benefits
    /// from compression.
    pub const IMAGES: Self = Self {
        content_type: Str::Static("image/"),
        exception: Some(Str::Static("image/svg+xml")),
    };

    /// Creates a predicate from a runtime content-type prefix.
    pub fn new(content_type: &str) -> Self {
        Self {
            content_type: Str::Shared(content_type.into()),
            exception: None,
        }
    }

    /// Creates a predicate from a static content-type prefix; usable in `const` contexts.
    pub const fn const_new(content_type: &'static str) -> Self {
        Self {
            content_type: Str::Static(content_type),
            exception: None,
        }
    }
}

impl Predicate for NotForContentType {
    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
    where
        B: Body,
    {
        let actual = content_type(response);

        if let Some(exception) = &self.exception {
            // Compare only the media type itself, ignoring parameters such as `; charset=...`.
            let essence = actual.split(';').next().unwrap_or(actual).trim();
            if essence.eq_ignore_ascii_case(exception.as_str()) {
                return true;
            }
        }

        !starts_with_ignore_ascii_case(actual, self.content_type.as_str())
    }
}

/// Returns `true` if `haystack` begins with `prefix`, ignoring ASCII case.
fn starts_with_ignore_ascii_case(haystack: &str, prefix: &str) -> bool {
    // The length check guarantees the slice below is in bounds.
    haystack.len() >= prefix.len()
        && haystack.as_bytes()[..prefix.len()].eq_ignore_ascii_case(prefix.as_bytes())
}

/// A string that is either borrowed for `'static` (usable in `const`) or cheaply shared.
#[derive(Clone)]
enum Str {
    Static(&'static str),
    Shared(Arc<str>),
}

impl Str {
    /// Borrows the contained string regardless of how it is stored.
    fn as_str(&self) -> &str {
        match self {
            Self::Static(s) => s,
            Self::Shared(s) => s,
        }
    }
}

impl fmt::Debug for Str {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fully qualified so it does not rely on the `Debug` trait being imported.
        fmt::Debug::fmt(self.as_str(), f)
    }
}
/// Returns the response's `Content-Type` header as a `&str`.
///
/// Yields `""` when the header is absent or when `HeaderValue::to_str` rejects its bytes.
fn content_type<B>(response: &http::Response<B>) -> &str {
    match response.headers().get(header::CONTENT_TYPE) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    }
}