use async_compression::Level;
use lazy_static::lazy_static;
use rocket::{
fairing::{Fairing, Info, Kind},
http::{hyper::header::CONTENT_ENCODING, Header, MediaType},
tokio::{
io::{AsyncRead, ReadBuf},
sync::RwLock,
},
Request, Response,
};
use std::{collections::HashMap, io::Cursor, task::Poll};
use crate::{CompressionUtils, Encoding};
/// Compression format of a cached response body.
///
/// Used together with the request path as the key of the in-memory cache of
/// compressed bodies, so each path can hold one entry per format.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) enum CachedEncoding {
/// Body compressed with gzip.
Gzip,
/// Body compressed with Brotli.
Brotli,
}
lazy_static! {
    /// Media types that are excluded from compression by default.
    static ref EXCLUSIONS: Vec<MediaType> = [
        "application/gzip",
        "application/zip",
        "image/*",
        "video/*",
        "application/octet-stream",
        "text/event-stream",
    ]
    .iter()
    .map(|raw| MediaType::parse_flexible(raw).unwrap())
    .collect();

    /// Process-wide cache of compressed bodies, keyed by request path and
    /// encoding. Values are `'static` slices, i.e. buffers that are never freed.
    static ref CACHED_FILES: RwLock<HashMap<(String, CachedEncoding), &'static [u8]>> =
        RwLock::new(HashMap::new());
}
/// Fairing that hands every response to the compressor, using a configurable
/// compression level and a list of excluded media types.
pub struct Compression {
    /// Compression level passed to the compressor.
    pub level: Level,
    /// Media types passed to the compressor as exclusions.
    pub excluded_content_types: Vec<MediaType>,
}

impl Compression {
    /// Creates the fairing with `Level::Default` and the built-in exclusion list.
    pub fn fairing() -> Compression {
        Self::with_level(Level::Default)
    }

    /// Creates the fairing with the given compression `level` and the built-in
    /// exclusion list.
    pub fn with_level(level: Level) -> Compression {
        Compression {
            level,
            excluded_content_types: EXCLUSIONS.clone(),
        }
    }

    /// Replaces the whole exclusion list with `excluded_content_types`,
    /// returning the updated fairing (builder style).
    pub fn exclude_content_types(self, excluded_content_types: Vec<MediaType>) -> Self {
        Compression {
            excluded_content_types,
            ..self
        }
    }

    /// Misspelled alias of [`Compression::exclude_content_types`].
    ///
    /// Kept so existing callers keep compiling; prefer the correctly spelled
    /// method in new code.
    pub fn exlude_content_types(self, excluded_content_types: Vec<MediaType>) -> Self {
        self.exclude_content_types(excluded_content_types)
    }

    /// Gives mutable access to the exclusion list so entries can be added or
    /// removed in place.
    pub fn excluded_content_types(&mut self) -> &mut Vec<MediaType> {
        &mut self.excluded_content_types
    }
}
#[rocket::async_trait]
impl Fairing for Compression {
fn info(&self) -> Info {
Info {
name: "Response compression",
kind: Kind::Response,
}
}
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
super::CompressionUtils::compress_response(
request,
response,
&self.excluded_content_types,
self.level,
);
}
}
/// Fairing that compresses responses for selected paths and keeps the
/// compressed bodies in an in-memory cache.
#[derive(Default)]
pub struct CachedCompression {
    /// Paths that are cached when the request path matches exactly.
    pub cached_paths: Vec<String>,
    /// Paths are cached when the request path starts with one of these.
    pub cached_path_prefixes: Vec<String>,
    /// Paths are cached when the request path ends with one of these.
    pub cached_path_suffixes: Vec<String>,
    /// Request paths starting with one of these are never handled, even if
    /// another rule matches.
    pub excluded_path_prefixes: Vec<String>,
    /// Compression level; `Level::Default` is used when `None`.
    pub level: Option<Level>,
}

impl CachedCompression {
    /// Fairing that caches only the exact paths in `cached_paths`.
    pub fn exact_path_fairing(cached_paths: Vec<String>) -> CachedCompression {
        Self {
            cached_paths,
            ..Self::default()
        }
    }

    /// Fairing that caches every path ending with one of `cached_path_suffixes`.
    pub fn path_suffix_fairing(cached_path_suffixes: Vec<String>) -> CachedCompression {
        Self {
            cached_path_suffixes,
            ..Self::default()
        }
    }

    /// Fairing that caches every path starting with one of `cached_path_prefixes`.
    pub fn path_prefix_fairing(cached_path_prefixes: Vec<String>) -> CachedCompression {
        Self {
            cached_path_prefixes,
            ..Self::default()
        }
    }

    /// Fairing that caches every path except those starting with one of
    /// `excluded_path_prefixes`.
    pub fn excluded_path_prefix_fairing(excluded_path_prefixes: Vec<String>) -> CachedCompression {
        // The empty prefix matches every path, so everything not excluded is cached.
        let match_everything = vec![String::new()];
        Self {
            cached_path_prefixes: match_everything,
            excluded_path_prefixes,
            ..Self::default()
        }
    }

    /// Converts a list of string slices into the owned strings the
    /// constructors expect.
    pub fn static_paths(paths: Vec<&str>) -> Vec<String> {
        paths.into_iter().map(String::from).collect()
    }
}
/// Response body whose every read fails.
///
/// The wrapped error is returned by the first read; later reads return a
/// generic "already read" error.
struct ErrorBody(Option<std::io::Error>);

impl AsyncRead for ErrorBody {
    /// Always completes immediately with an error and never writes to the buffer.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        _buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let failure = self.0.take().unwrap_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::Other, "ErrorBody already read")
        });
        Poll::Ready(Err(failure))
    }
}
#[rocket::async_trait]
impl Fairing for CachedCompression {
fn info(&self) -> Info {
Info {
name: "Cached response compression",
kind: Kind::Response,
}
}
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
let path = request.uri().path().to_string();
let excluded_from_cache = self
.excluded_path_prefixes
.iter()
.any(|s| path.starts_with(s));
let cache_compressed_responses = !excluded_from_cache
&& (self.cached_paths.iter().any(|s| path.eq(s))
|| self.cached_path_suffixes.iter().any(|s| path.ends_with(s))
|| self
.cached_path_prefixes
.iter()
.any(|s| path.starts_with(s)));
if !cache_compressed_responses {
return;
}
let (accepts_gzip, accepts_br) = CompressionUtils::accepted_algorithms(request);
if !accepts_gzip && !accepts_br {
return;
}
if CompressionUtils::already_encoded(response) {
return;
}
let content_type = response.content_type();
if CompressionUtils::skip_encoding(&content_type, &EXCLUSIONS) {
return;
}
let desired_encoding = if accepts_br {
CachedEncoding::Brotli
} else {
CachedEncoding::Gzip
};
let encoding = match desired_encoding {
CachedEncoding::Gzip => Encoding::Gzip,
CachedEncoding::Brotli => Encoding::Brotli,
};
if cache_compressed_responses && (accepts_gzip || accepts_br) {
let cached_body = {
let guard = CACHED_FILES.read().await;
let body = guard.get(&(path.clone(), desired_encoding)).copied();
drop(guard);
body
};
if let Some(cached_body) = cached_body {
debug!("Found cached response for {}", path);
response.set_header(Header::new(
CONTENT_ENCODING.as_str(),
format!("{}", encoding),
));
response.set_sized_body(cached_body.len(), Cursor::new(cached_body));
return;
}
}
let body = response.body_mut().take();
let compressed_body: Vec<u8> = match CompressionUtils::compress_body(
body,
desired_encoding,
self.level.unwrap_or(Level::Default),
)
.await
{
Ok(compressed_body) => compressed_body,
Err(err) => {
error!("Failed to compress response body for {}; underlying `AsyncRead` likely failed: {}", path, err);
response.set_streamed_body(ErrorBody(Some(err)));
return;
}
};
response.set_header(Header::new(
CONTENT_ENCODING.as_str(),
format!("{}", encoding),
));
response.set_sized_body(compressed_body.len(), Cursor::new(compressed_body.clone()));
debug!("Setting cached response for {}", path);
CACHED_FILES
.write()
.await
.insert((path, desired_encoding), Vec::leak(compressed_body));
}
}