use crate::prelude::{chrono::*, *};
#[allow(unused_imports)]
use std::cell::UnsafeCell;
use std::{borrow::Borrow, hash::Hash};
/// `time` format description producing HTTP dates of the fixed form
/// `Sun, 06 Nov 1994 08:49:37 GMT`: short weekday and month names,
/// zero-padded day/time fields, 24-hour clock, and a literal `GMT` suffix.
///
/// The `GMT` text is appended literally, not derived from the value's offset,
/// so the formatted date is only correct when the `OffsetDateTime` is already UTC.
pub static HTTP_DATE: &[time::format_description::FormatItem] = time::macros::format_description!("[weekday repr:short case_sensitive:true], [day padding:zero] [month repr:short case_sensitive:true] [year padding:zero repr:full base:calendar sign:automatic] [hour repr:24 padding:zero]:[minute padding:zero]:[second padding:zero] GMT");
/// Cache of file contents as `(timestamp, bytes)`.
///
/// NOTE(review): presumably keyed by file path, with `None` recording a
/// missing/unreadable file and the timestamp being a load or modification
/// time. Confirm with the code that fills this cache.
pub type FileCache = MokaCache<CompactString, Option<(OffsetDateTime, Bytes)>>;
/// Cache of whole responses keyed by [`UriKey`].
///
/// Each value also carries its insertion time, a pre-formatted date header and
/// an optional lifetime (see [`LifetimeCache`]).
pub type ResponseCache = MokaCache<UriKey, LifetimeCache<Arc<VariedResponse>>>;
/// A request path and its query stored in one string.
///
/// `string` holds the path immediately followed by the query text. The `?`
/// separator is not stored; see the `From<&Uri>` impl. `query_start` is the byte
/// offset where the query begins. It equals `string.len()` when the URI had no
/// query or an empty one (`/index.html?`).
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[must_use]
pub struct PathQuery {
// Path followed directly by the query, without the `?`.
string: CompactString,
// Byte index into `string` where the query starts (== len when absent).
query_start: usize,
}
impl PathQuery {
    /// The path portion: everything before `query_start`.
    #[inline]
    #[must_use]
    pub fn path(&self) -> &str {
        self.string.split_at(self.query_start).0
    }
    /// The query portion (without `?`), or `None` when it is absent or empty.
    #[inline]
    #[must_use]
    pub fn query(&self) -> Option<&str> {
        let has_query = self.query_start != self.string.len();
        has_query.then(|| &self.string[self.query_start..])
    }
    /// Drops the query in place, leaving only the path.
    pub fn truncate_query(&mut self) {
        self.string.truncate(self.query_start);
    }
    /// Consumes `self` and returns just the path.
    #[inline]
    #[must_use]
    pub fn into_path(self) -> CompactString {
        let Self {
            mut string,
            query_start,
        } = self;
        string.truncate(query_start);
        string
    }
}
impl From<&Uri> for PathQuery {
    /// Builds a [`PathQuery`] by concatenating the URI's path and query.
    ///
    /// The `?` separator is not stored; the query's position is remembered
    /// in `query_start`, which is always the length of the path.
    fn from(uri: &Uri) -> Self {
        let path = uri.path();
        let query_start = path.len();
        let string = if let Some(query) = uri.query() {
            let mut joined = CompactString::with_capacity(path.len() + query.len());
            joined.push_str(path);
            joined.push_str(query);
            joined
        } else {
            path.to_compact_string()
        };
        Self {
            string,
            query_start,
        }
    }
}
/// Response-cache key: either the path alone, or the path together with its
/// query.
///
/// [`UriKey::call_all`] tries a `PathQuery` key first and then falls back to the
/// path-only key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UriKey {
/// Path only; the query is ignored.
Path(CompactString),
/// Path and query.
PathQuery(PathQuery),
}
impl UriKey {
#[inline]
pub fn path_and_query(uri: &Uri) -> Self {
Self::PathQuery(PathQuery::from(uri))
}
#[inline]
pub fn call_all<T>(
mut self,
mut callback: impl FnMut(&Self) -> Option<T>,
) -> (Self, Option<T>) {
match callback(&self) {
Some(t) => (self, Some(t)),
None => match self {
Self::Path(_) => (self, None),
Self::PathQuery(path_query) => {
self = Self::Path(path_query.into_path());
let result = callback(&self);
(self, result)
}
},
}
}
}
/// Guesses the MIME type of a file.
///
/// The guess comes from its `extension` when that is known. Otherwise
/// `file_contents` is sniffed. If the result cannot be parsed, this falls back
/// to `application/octet-stream`.
#[must_use]
pub fn get_mime(extension: &str, file_contents: &[u8]) -> Mime {
    let raw = match mime_guess::from_ext(extension).first_raw() {
        Some(by_extension) => by_extension,
        None => tree_magic_mini::from_u8(file_contents),
    };
    raw.parse().unwrap_or(mime::APPLICATION_OCTET_STREAM)
}
/// Whether `mime` is textual.
///
/// This is true for any `text/*` type, and for `application/` `javascript`,
/// `graphql`, `json` or `xml`.
#[must_use]
pub fn is_text(mime: &Mime) -> bool {
    if mime.type_() == mime::TEXT {
        return true;
    }
    if mime.type_() != mime::APPLICATION {
        return false;
    }
    let sub = mime.subtype();
    sub == mime::JAVASCRIPT || sub == "graphql" || sub == mime::JSON || sub == mime::XML
}
/// Whether a body of type `mime` is worth compressing.
///
/// The following are rejected:
/// - non-SVG images;
/// - fonts, video, audio and `*` types;
/// - PDF, `zip` and `zstd` subtypes;
/// - `application/*` types other than javascript, graphql, json, xml, wasm and
///   octet-stream.
///
/// Everything else is accepted.
#[must_use]
pub fn do_compress(mime: &Mime) -> bool {
    let top = mime.type_();
    let sub = mime.subtype();
    // Non-SVG images are assumed to already be in a compressed format.
    if top == mime::IMAGE && sub != "svg" {
        return false;
    }
    if top == mime::FONT || top == mime::VIDEO || top == mime::AUDIO || top == mime::STAR {
        return false;
    }
    if *mime == mime::APPLICATION_PDF || sub == "zip" || sub == "zstd" {
        return false;
    }
    if top == mime::APPLICATION {
        return sub == mime::JAVASCRIPT
            || sub == "graphql"
            || sub == mime::JSON
            || sub == mime::XML
            || sub == "wasm"
            || sub == "octet-stream";
    }
    true
}
/// Content encoding to prefer when the client accepts several.
///
/// Each compressed variant exists only when its cargo feature (`zstd`, `br`,
/// `gzip`) is enabled. `None` stands for the identity encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum PreferredCompression {
#[cfg(feature = "zstd")]
Zstd,
#[cfg(feature = "br")]
Brotli,
#[cfg(feature = "gzip")]
Gzip,
None,
}
impl PreferredCompression {
    /// The `content-encoding` token for this choice.
    ///
    /// [`PreferredCompression::None`] maps to `"identity"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::None => "identity",
            #[cfg(feature = "gzip")]
            Self::Gzip => "gzip",
            #[cfg(feature = "br")]
            Self::Brotli => "br",
            #[cfg(feature = "zstd")]
            Self::Zstd => "zstd",
        }
    }
}
impl Default for PreferredCompression {
    /// The first compiled-in encoding, in the order zstd, brotli, gzip.
    ///
    /// Falls back to `None` when no compression feature is enabled.
    fn default() -> Self {
        const BY_PRIORITY: &[PreferredCompression] = &[
            #[cfg(feature = "zstd")]
            PreferredCompression::Zstd,
            #[cfg(feature = "br")]
            PreferredCompression::Brotli,
            #[cfg(feature = "gzip")]
            PreferredCompression::Gzip,
            PreferredCompression::None,
        ];
        // `None` is always present, so index 0 exists.
        BY_PRIORITY[0]
    }
}
/// Which encoding to prefer, and the compression level used for each
/// compiled-in algorithm.
#[derive(Debug, Clone)]
pub struct CompressionOptions {
/// Encoding tried first, if the client accepts it.
pub preferred: PreferredCompression,
/// Level passed to `zstd::Encoder::new`.
#[cfg(feature = "zstd")]
pub zstd_level: i32,
/// Quality passed to `brotli::CompressorWriter::new`.
#[cfg(feature = "br")]
pub brotli_level: u32,
/// Level passed to `flate2::Compression::new`.
#[cfg(feature = "gzip")]
pub gzip_level: u32,
}
impl CompressionOptions {
    /// Low, fast compression levels: zstd 1, brotli 3, gzip 1.
    ///
    /// `preferred` is [`PreferredCompression::default`].
    #[must_use]
    pub fn oneshot() -> Self {
        Self {
            preferred: PreferredCompression::default(),
            #[cfg(feature = "zstd")]
            zstd_level: 1,
            #[cfg(feature = "br")]
            brotli_level: 3,
            #[cfg(feature = "gzip")]
            gzip_level: 1,
        }
    }
    /// Somewhat higher compression levels: zstd 4, brotli 4, gzip 2.
    ///
    /// Otherwise identical to [`CompressionOptions::oneshot`]. The name suggests
    /// these are meant for bodies that will be cached, where the extra CPU cost
    /// is paid once.
    #[must_use]
    pub fn cached() -> Self {
        Self {
            #[cfg(feature = "zstd")]
            zstd_level: 4,
            #[cfg(feature = "br")]
            brotli_level: 4,
            #[cfg(feature = "gzip")]
            gzip_level: 2,
            ..Self::oneshot()
        }
    }
}
/// A response kept in its uncompressed ("identity") form, plus lazily computed
/// compressed bodies.
///
/// Each compressed body starts as `None`. It is filled in on first use by
/// `get_gzip` / `get_br` / `get_zstd`, which write through the `UnsafeCell`s
/// from a shared `&self`.
///
/// NOTE(review): those writes are not synchronized; see the `Send`/`Sync` impls
/// that follow. A `std::sync::OnceLock<Bytes>` per encoding would make this sound.
#[derive(Debug)]
#[must_use]
pub struct CompressedResponse {
// The original, uncompressed response; its headers are reused for every encoding.
identity: Response<Bytes>,
// Memoized gzip body (`None` until first requested).
#[cfg(feature = "gzip")]
gzip: UnsafeCell<Option<Bytes>>,
// Memoized brotli body (`None` until first requested).
#[cfg(feature = "br")]
br: UnsafeCell<Option<Bytes>>,
// Memoized zstd body (`None` until first requested).
#[cfg(feature = "zstd")]
zstd: UnsafeCell<Option<Bytes>>,
// Whether compression is allowed at all for this response.
compress: CompressPreference,
}
// SAFETY: NOTE(review) this is not currently upheld.
// `get_gzip`, `get_br` and `get_zstd` read and then overwrite the
// `UnsafeCell<Option<Bytes>>` fields through `&self`, with an `.await` in
// between and no lock. If the same getter runs on two threads against a
// shared `CompressedResponse`, those unsynchronized accesses are a data race,
// which is undefined behaviour. Values in `ResponseCache` sit behind an `Arc`,
// which suggests such sharing happens (TODO confirm). Replacing the cells with
// `std::sync::OnceLock<Bytes>` would make both impls unnecessary.
unsafe impl Send for CompressedResponse {}
unsafe impl Sync for CompressedResponse {}
impl CompressedResponse {
pub(crate) fn new(
mut identity: Response<Bytes>,
compress: CompressPreference,
client_cache: ClientCachePreference,
extension: &str,
) -> Self {
let headers = identity.headers_mut();
Self::set_client_cache(headers, client_cache);
Self::check_content_type(&mut identity, extension);
Self {
identity,
#[cfg(feature = "gzip")]
gzip: UnsafeCell::new(None),
#[cfg(feature = "br")]
br: UnsafeCell::new(None),
#[cfg(feature = "zstd")]
zstd: UnsafeCell::new(None),
compress,
}
}
#[cfg(feature = "gzip")]
fn gzip(&self) -> &Option<Bytes> {
unsafe { &*self.gzip.get() }
}
#[cfg(feature = "br")]
fn br(&self) -> &Option<Bytes> {
unsafe { &*self.br.get() }
}
#[cfg(feature = "zstd")]
fn zstd(&self) -> &Option<Bytes> {
unsafe { &*self.zstd.get() }
}
#[inline]
pub fn get_identity(&self) -> &Response<Bytes> {
&self.identity
}
#[allow(clippy::unused_async)] pub async fn clone_preferred<T>(
&self,
request: &Request<T>,
regular_options: &CompressionOptions,
cached_options: &CompressionOptions,
do_cache: bool,
) -> Result<Response<Bytes>, &'static str> {
if self.compress == CompressPreference::None {
return Ok(self.clone_identity_set_compression(
self.get_identity().body().clone(),
HeaderValue::from_static("identity"),
));
}
let options = if do_cache {
cached_options
} else {
regular_options
};
let values = match request
.headers()
.get("accept-encoding")
.map(HeaderValue::to_str)
.and_then(Result::ok)
{
Some(header) => utils::list_header(header),
None => Vec::new(),
};
let disable_identity = values
.iter()
.any(|v| v.value == "identity" && v.quality == 0.0);
let only_identity = values.len() == 1
&& values[0]
== utils::ValueQualitySet {
value: "identity",
quality: 1.0,
};
if only_identity {
return Ok(self.clone_identity_set_compression(
Bytes::clone(self.get_identity().body()),
HeaderValue::from_static("identity"),
));
}
#[cfg(any(feature = "gzip", feature = "br", feature = "zstd"))]
let contains = |name| values.iter().any(|v| v.value == name && v.quality != 0.0);
let mime = self
.get_identity()
.headers()
.get("content-type")
.and_then(|header| header.to_str().ok())
.and_then(|header| header.parse().ok());
debug!("Recognised mime {:?}", &mime);
let (bytes, compression) = match &mime {
Some(mime) => {
if do_compress(mime) {
#[cfg(feature = "zstd")]
let contains_zstd = contains("zstd");
#[cfg(feature = "br")]
let contains_br = contains("br");
#[cfg(feature = "gzip")]
let contains_gzip = contains("gzip");
#[allow(unused_mut)]
let mut preferred = match options.preferred.as_str() {
#[cfg(feature = "zstd")]
"zstd" if contains_zstd => {
Some((self.get_zstd(options.zstd_level).await, "zstd"))
}
#[cfg(feature = "br")]
"br" if contains_br => {
Some((self.get_br(options.brotli_level).await, "br"))
}
#[cfg(feature = "gzip")]
"gzip" if contains_gzip => {
Some((self.get_gzip(options.gzip_level).await, "gzip"))
}
_ => None,
};
#[cfg(feature = "zstd")]
if preferred.is_none() && contains_zstd {
preferred = Some((self.get_zstd(options.zstd_level).await, "zstd"));
}
#[cfg(feature = "br")]
if preferred.is_none() && contains_br {
preferred = Some((self.get_br(options.brotli_level).await, "br"));
}
#[cfg(feature = "gzip")]
if preferred.is_none() && contains_gzip {
preferred = Some((self.get_gzip(options.gzip_level).await, "gzip"));
}
preferred.unwrap_or_else(|| (self.get_identity().body(), "identity"))
} else {
debug!("Not compressing; filtered out.");
(self.get_identity().body(), "identity")
}
}
None => (self.get_identity().body(), "identity"),
};
if disable_identity && compression == "identity" {
return Err(
"identity compression is the only option, but the client refused to accept it",
);
}
Ok(self.clone_identity_set_compression(
Bytes::clone(bytes),
HeaderValue::from_static(compression),
))
}
#[inline]
fn set_client_cache(headers: &mut HeaderMap, preference: ClientCachePreference) {
let header = preference.as_header();
if let Some(h) = header {
headers.entry("cache-control").or_insert(h);
}
}
fn check_content_type(identity_response: &mut Response<Bytes>, extension: &str) {
fn add_utf_8(headers: &mut HeaderMap, mime: &Mime) {
let charset = if mime.get_param(mime::CHARSET) == Some(mime::UTF_8) {
""
} else {
"; charset=utf-8"
};
let mut header = mime.to_string();
header.push_str(charset);
let content_type =
HeaderValue::from_maybe_shared::<Bytes>(header.into_bytes().into()).unwrap();
headers.insert("content-type", content_type);
}
match identity_response.headers().get("content-type") {
Some(content_type) => {
if let Some(mime_type) = content_type
.to_str()
.ok()
.and_then(|s| s.parse::<Mime>().ok())
{
#[allow(clippy::match_same_arms)] match mime_type.get_param("charset") {
None if is_text(&mime_type) => {
add_utf_8(identity_response.headers_mut(), &mime_type);
}
Some(_) | None => {}
}
}
}
None if !identity_response.body().is_empty() => {
let short_body = identity_response
.body()
.get(..8)
.unwrap_or(identity_response.body());
let mime = get_mime(extension, short_body);
let utf_8 = is_text(&mime);
if utf_8 {
add_utf_8(identity_response.headers_mut(), &mime);
} else {
let content_type = HeaderValue::from_maybe_shared::<Bytes>(
mime.to_string().into_bytes().into(),
)
.unwrap();
identity_response
.headers_mut()
.insert("content-type", content_type);
}
}
None => {}
}
}
fn clone_identity_set_compression(
&self,
new_data: Bytes,
compression: HeaderValue,
) -> Response<Bytes> {
let response = &self.identity;
let mut builder = Response::builder()
.version(response.version())
.status(response.status());
let mut map = response.headers().clone();
if !new_data.is_empty() {
let headers = &mut map;
debug!(
"Changing content-encoding from {:?}. Has content-type {:?}",
headers.get("content-encoding"),
headers.get("content-type"),
);
headers.insert("content-encoding", compression);
}
*builder.headers_mut().unwrap() = map;
builder.body(new_data).unwrap()
}
#[cfg(feature = "gzip")]
pub async fn get_gzip(&self, level: u32) -> &Bytes {
if self.gzip().is_none() {
let bytes = self.identity.body().clone();
let buffer = threading::spawn_blocking(move || {
let mut buffer = utils::WriteableBytes::with_capacity(bytes.len() / 3 + 64);
let mut c =
flate2::write::GzEncoder::new(&mut buffer, flate2::Compression::new(level));
c.write_all(&bytes).expect("Failed to compress using gzip!");
c.finish().expect("Failed to compress using gzip!");
let buffer = buffer.into_inner();
buffer.freeze()
})
.await
.unwrap();
if self.gzip().is_none() {
unsafe { (*self.gzip.get()).replace(buffer) };
}
}
self.gzip().as_ref().unwrap()
}
#[cfg(feature = "br")]
pub async fn get_br(&self, level: u32) -> &Bytes {
if self.br().is_none() {
let bytes = self.identity.body().clone();
let buffer = threading::spawn_blocking(move || {
let mut buffer = utils::WriteableBytes::with_capacity(bytes.len() / 3 + 64);
let mut c = brotli::CompressorWriter::new(&mut buffer, 4096, level, 21);
c.write_all(&bytes)
.expect("Failed to compress using Brotli!");
c.flush().expect("Failed to compress using Brotli!");
c.into_inner();
let buffer = buffer.into_inner();
buffer.freeze()
})
.await
.unwrap();
if self.br().is_none() {
unsafe { (*self.br.get()).replace(buffer) };
}
}
self.br().as_ref().unwrap()
}
#[cfg(feature = "zstd")]
pub async fn get_zstd(&self, level: i32) -> &Bytes {
if self.zstd().is_none() {
let bytes = self.identity.body().clone();
let buffer = threading::spawn_blocking(move || {
let mut buffer = utils::WriteableBytes::with_capacity(bytes.len() / 3 + 64);
let mut encoder = zstd::Encoder::new(&mut buffer, level).unwrap();
#[cfg(feature = "zstd-multithread")]
#[allow(clippy::cast_possible_truncation)]
if let Err(err) = encoder
.multithread(std::thread::available_parallelism().map_or(8, |v| v.get() as u32))
{
error!("Failed to enable multithread support for zstd: {err}");
}
encoder
.write_all(&bytes)
.expect("Failed to compress using Zstd!");
encoder.flush().expect("Failed to compress using Zstd!");
let buffer = buffer.into_inner();
buffer.freeze()
})
.await
.unwrap();
if self.zstd().is_none() {
unsafe { (*self.zstd.get()).replace(buffer) };
}
}
self.zstd().as_ref().unwrap()
}
}
/// Whether a response may be compressed.
///
/// With `None`, `CompressedResponse::clone_preferred` always returns the
/// identity body.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum CompressPreference {
None,
Full,
}
/// Error returned when parsing a [`ServerCachePreference`] or
/// [`ClientCachePreference`] from a string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum CachePreferenceError {
/// The input was the empty string.
Empty,
/// The input matched no keyword and no `<seconds>s` duration.
Invalid,
/// The input was a `0s` duration.
ZeroDuration,
}
/// How the server-side response cache treats a response.
///
/// Parsed from `"none"`, `"full"`, `"query_matters"` (and spelling variants), or
/// `"<n>s"` for [`ServerCachePreference::MaxAge`].
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCachePreference {
/// Never cache (`cache` returns `false`).
None,
/// Cache, keyed including the query (`query_matters` returns `true`).
QueryMatters,
/// Cache, with the query not part of the key.
Full,
/// Cache with a maximum age. NOTE(review): how the duration is enforced
/// is not visible in this file — confirm at the call site.
MaxAge(Duration),
}
impl ServerCachePreference {
    /// Whether a response should go into the server cache.
    ///
    /// Two conditions must both hold:
    /// - this preference is not [`ServerCachePreference::None`];
    /// - `cache_action.into_cache()` is true and the request is `GET` or `HEAD`.
    #[inline]
    #[must_use]
    pub fn cache(self, cache_action: host::CacheAction, method: &Method) -> bool {
        let cacheable_request =
            cache_action.into_cache() && (*method == Method::GET || *method == Method::HEAD);
        let enabled = match self {
            Self::QueryMatters | Self::Full | Self::MaxAge(_) => true,
            Self::None => false,
        };
        enabled && cacheable_request
    }
    /// Whether the query string should be part of the cache key.
    #[inline]
    #[must_use]
    pub fn query_matters(self) -> bool {
        match self {
            Self::QueryMatters => true,
            Self::None | Self::Full | Self::MaxAge(_) => false,
        }
    }
}
impl str::FromStr for ServerCachePreference {
    type Err = CachePreferenceError;
    /// Parses a keyword or an `<seconds>s` duration.
    ///
    /// The keywords are `"full"`, `"none"` and the `query_matters` spellings.
    ///
    /// # Errors
    ///
    /// - `Empty` for `""`;
    /// - `ZeroDuration` for `"0s"`;
    /// - `Invalid` for anything else that is unrecognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "" => Err(CachePreferenceError::Empty),
            "none" => Ok(Self::None),
            "full" => Ok(Self::Full),
            "query_matters" | "query-matters" | "QueryMatters" | "queryMatters" => {
                Ok(Self::QueryMatters)
            }
            other => {
                let seconds = other
                    .strip_suffix('s')
                    .and_then(|digits| digits.parse::<u64>().ok())
                    .ok_or(CachePreferenceError::Invalid)?;
                if seconds == 0 {
                    Err(CachePreferenceError::ZeroDuration)
                } else {
                    Ok(Self::MaxAge(seconds.std_seconds()))
                }
            }
        }
    }
}
/// The `cache-control` policy sent to clients (see `as_header`).
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ClientCachePreference {
/// Send no `cache-control` header.
Ignore,
/// `no-store`.
None,
/// `max-age=120`.
Changing,
/// `public, max-age=604800, immutable` (one week).
Full,
/// `public, max-age=<seconds rounded up>, immutable`.
MaxAge(Duration),
}
impl ClientCachePreference {
    /// The `cache-control` value for this preference.
    ///
    /// Returns `None` for [`ClientCachePreference::Ignore`].
    #[inline]
    #[must_use]
    pub fn as_header(self) -> Option<HeaderValue> {
        match self {
            Self::Ignore => None,
            Self::None => Some(HeaderValue::from_static("no-store")),
            Self::Changing => Some(HeaderValue::from_static("max-age=120")),
            Self::Full => Some(HeaderValue::from_static("public, max-age=604800, immutable")),
            Self::MaxAge(duration) => {
                // Round any sub-second remainder up so the lifetime is never shortened.
                let seconds = duration.as_secs() + u64::from(duration.subsec_nanos() > 0);
                let seconds_text = seconds.to_string();
                let bytes = build_bytes!(
                    b"public, max-age=",
                    seconds_text.as_bytes(),
                    b", immutable"
                );
                Some(HeaderValue::from_maybe_shared(bytes).unwrap())
            }
        }
    }
}
impl str::FromStr for ClientCachePreference {
    type Err = CachePreferenceError;
    /// Parses an `<seconds>s` duration or one of the keywords.
    ///
    /// The duration form is checked first. The keywords are `"ignore"`,
    /// `"full"`, `"changing"` and `"none"`.
    ///
    /// # Errors
    ///
    /// - `ZeroDuration` for `"0s"`;
    /// - `Empty` for `""`;
    /// - `Invalid` for anything else that is unrecognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let seconds = s
            .strip_suffix('s')
            .and_then(|digits| digits.parse::<u64>().ok());
        match (seconds, s) {
            (Some(0), _) => Err(CachePreferenceError::ZeroDuration),
            (Some(n), _) => Ok(Self::MaxAge(n.std_seconds())),
            (None, "ignore") => Ok(Self::Ignore),
            (None, "full") => Ok(Self::Full),
            (None, "changing") => Ok(Self::Changing),
            (None, "none") => Ok(Self::None),
            (None, "") => Err(CachePreferenceError::Empty),
            (None, _) => Err(CachePreferenceError::Invalid),
        }
    }
}
/// Outcome of a cache lookup or insertion.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum CacheOut<V> {
    /// Nothing to return: a miss, or a successful insertion.
    None,
    /// The value was found in the cache.
    Present(V),
    /// The value was handed back because it was not inserted.
    NotInserted(V),
}
impl<V> CacheOut<V> {
    /// The carried value, if any, regardless of which variant carried it.
    #[inline]
    pub fn into_option(self) -> Option<V> {
        if let Self::Present(value) | Self::NotInserted(value) = self {
            Some(value)
        } else {
            None
        }
    }
    /// Applies `f` to the carried value, keeping the variant.
    pub fn map<K>(self, f: impl FnOnce(V) -> K) -> CacheOut<K> {
        match self {
            Self::Present(value) => CacheOut::Present(f(value)),
            Self::NotInserted(value) => CacheOut::NotInserted(f(value)),
            Self::None => CacheOut::None,
        }
    }
}
/// A [`moka::sync::Cache`] paired with a per-entry size limit.
#[derive(Debug)]
pub struct MokaCache<K: Hash + Eq + Send + Sync + 'static, V: Clone + Send + Sync + 'static> {
// Bodies of this many bytes or more are rejected by `insert` (response cache impl).
size_limit: usize,
/// The underlying moka cache.
pub cache: moka::sync::Cache<K, V>,
}
impl<K: Hash + Eq + Send + Sync + 'static, V: Clone + Send + Sync + 'static> Default
    for MokaCache<K, V>
{
    /// A cache with moka `max_capacity` 1024 and a 4 MiB size limit.
    fn default() -> Self {
        // 4 << 20 == 4 * 1024 * 1024 bytes.
        const SIZE_LIMIT: usize = 4 << 20;
        Self {
            cache: moka::sync::Cache::new(1024),
            size_limit: SIZE_LIMIT,
        }
    }
}
impl<K: Hash + Eq + Send + Sync + 'static> MokaCache<K, LifetimeCache<Arc<VariedResponse>>> {
    /// Looks up `key` and returns the entry only while it is still fresh.
    ///
    /// An entry with a lifetime is fresh while the time since insertion is at
    /// most that lifetime. An entry without a lifetime never expires here.
    /// Stale entries are invalidated and reported as [`CacheOut::None`].
    pub(crate) fn get_cache_item<Q: Hash + Eq>(
        &self,
        key: &Q,
    ) -> CacheOut<LifetimeCache<Arc<VariedResponse>>>
    where
        K: Borrow<Q>,
    {
        let entry = match self.cache.get(key) {
            Some(entry) => entry,
            None => return CacheOut::None,
        };
        let (_, (inserted_at, _, lifetime)) = &entry;
        let fresh = match lifetime {
            Some(lifetime) => OffsetDateTime::now_utc() - *inserted_at <= *lifetime,
            None => true,
        };
        if fresh {
            CacheOut::Present(entry)
        } else {
            self.cache.invalidate(key);
            CacheOut::None
        }
    }
    /// Stores `response` under `key`, stamped with the current UTC time and
    /// that time formatted as an HTTP date header.
    ///
    /// If `len` is at least the size limit, nothing is stored and the response
    /// is handed back as [`CacheOut::NotInserted`].
    pub(crate) fn insert(
        &self,
        len: usize,
        lifetime: Option<Duration>,
        key: K,
        response: VariedResponse,
    ) -> CacheOut<VariedResponse> {
        if len < self.size_limit {
            let inserted_at = OffsetDateTime::now_utc();
            let formatted = inserted_at
                .format(&comprash::HTTP_DATE)
                .expect("failed to format datetime");
            let date_header =
                HeaderValue::from_str(&formatted).expect("We know these bytes are valid.");
            self.cache.insert(
                key,
                (Arc::new(response), (inserted_at, date_header, lifetime)),
            );
            CacheOut::None
        } else {
            CacheOut::NotInserted(response)
        }
    }
    /// Inserts `response`, taking its lifetime from the freshness in the
    /// `cache-control` headers of its first variant's identity response.
    ///
    /// Its size is taken from that response's body length.
    pub(crate) fn insert_cache_item(
        &self,
        key: K,
        response: VariedResponse,
    ) -> CacheOut<VariedResponse> {
        let lifetime =
            parse::CacheControl::from_headers(response.first().0.get_identity().headers())
                .ok()
                .as_ref()
                .and_then(parse::CacheControl::as_freshness)
                .map(|s| u64::from(s).std_seconds());
        let body_len = response.first().0.get_identity().body().len();
        debug!("Inserted item to cache with lifetime {:?}", lifetime);
        self.insert(body_len, lifetime, key, response)
    }
}
/// A cached value together with `(inserted_at, date header, lifetime)`.
///
/// The header is `inserted_at` formatted with [`HTTP_DATE`]. A `None` lifetime
/// means the entry is never treated as stale by `get_cache_item`.
pub type LifetimeCache<T> = (T, (OffsetDateTime, HeaderValue, Option<Duration>));
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `uri` and returns the query that `PathQuery` extracts from it.
    fn query_of(uri: &str) -> Option<String> {
        let uri: Uri = uri.parse().unwrap();
        PathQuery::from(&uri).query().map(str::to_owned)
    }

    #[test]
    fn path_query_empty_query_1() {
        assert_eq!(query_of("https://kvarn.org/index.html?"), None);
    }
    #[test]
    fn path_query_empty_query_2() {
        assert_eq!(
            query_of("https://kvarn.org/index.html?hi").as_deref(),
            Some("hi")
        );
    }
    #[test]
    fn path_query_empty_query_3() {
        assert_eq!(
            query_of("https://kvarn.org/index.html??").as_deref(),
            Some("?")
        );
    }
    #[test]
    fn path_query_empty_query_4() {
        assert_eq!(query_of("https://kvarn.org/"), None);
    }
}