use std::{
collections::HashSet,
env,
fmt::Debug,
fs::File,
io::ErrorKind,
path::{Path, PathBuf},
sync::Arc,
sync::LazyLock,
};
use brotli::enc::backward_references::BrotliEncoderMode;
use brotli::enc::BrotliEncoderParams;
use sieve_cache::ShardedSieveCache;
use std::io::Write;
use crate::compression::{ContentEncoding, MatchedFile};
use crate::SerdirError;
/// Key under which compressed results are stored in the cache: the
/// `FileInfo` of the original (uncompressed) file.
type CacheKey = crate::FileInfo;
/// Buffer size handed to `brotli::CompressorWriter` when compressing a file.
const BUF_SIZE: usize = 8192;
/// File extensions that are compressed when the settings do not provide their
/// own extension list.
static DEFAULT_TEXT_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    const EXTENSIONS: [&str; 12] = [
        "html", "htm", "xhtml", "css", "js", "json", "xml", "svg", "txt", "csv", "tsv", "md",
    ];
    EXTENSIONS.iter().copied().collect()
});
/// Cache of brotli-compressed copies of files. Copies are produced on demand
/// and stored in temporary files.
pub(crate) struct BrotliCache {
/// Directory in which the temporary compressed files are created.
tempdir: PathBuf,
/// Compressed results, keyed by the `FileInfo` of the original file.
cache: ShardedSieveCache<CacheKey, MatchedFile>,
/// Base encoder parameters; cloned and adjusted (mode, size hint) for each
/// compression.
params: BrotliEncoderParams,
/// Extensions eligible for compression; files with any other extension are
/// returned as-is.
supported_extensions: HashSet<&'static str>,
/// Files whose length exceeds this value are returned uncompressed.
max_file_size: u64,
}
impl From<crate::compression::CachedCompression> for BrotliCache {
    /// Builds a cache from the compression settings.
    ///
    /// Temporary files are placed in the system temp directory. When the
    /// settings carry no extension list, [`DEFAULT_TEXT_TYPES`] is used.
    ///
    /// # Panics
    ///
    /// Panics if the underlying sieve cache rejects the requested capacity
    /// (per the `expect` message, a capacity of zero).
    fn from(value: crate::compression::CachedCompression) -> Self {
        let tempdir = env::temp_dir();
        let cache = ShardedSieveCache::new(value.cache_size as usize)
            .expect("brotli cache capacity cannot be zero");
        let params = BrotliEncoderParams {
            quality: i32::from(value.compression_level),
            ..Default::default()
        };
        // `value` is owned, so the extension set is moved out instead of cloned;
        // only the shared default has to be copied.
        let supported_extensions = value
            .supported_extensions
            .unwrap_or_else(|| DEFAULT_TEXT_TYPES.clone());
        BrotliCache {
            tempdir,
            cache,
            params,
            supported_extensions,
            max_file_size: value.max_file_size,
        }
    }
}
impl BrotliCache {
/// Chooses the encoder mode for a file extension: text mode when the
/// extension is in `supported_extensions`, generic mode otherwise.
fn detect_brotli_mode(&self, extension: &str) -> BrotliEncoderMode {
    if self.supported_extensions.contains(extension) {
        BrotliEncoderMode::BROTLI_MODE_TEXT
    } else {
        BrotliEncoderMode::BROTLI_MODE_GENERIC
    }
}
pub(crate) async fn get(&self, path: &Path) -> Result<MatchedFile, SerdirError> {
let extension: &str = path
.extension()
.and_then(|s| s.to_str())
.unwrap_or_default();
if !self.supported_extensions.contains(extension) {
return Self::wrap_orig(path, extension);
}
let file_info = crate::FileInfo::for_path(path)?;
if file_info.len() > self.max_file_size {
return Self::wrap_orig(path, extension);
}
let matched: Option<MatchedFile> = self.cache.get(&file_info);
if let Some(f) = matched {
return Ok(f);
}
let mut params = self.params.clone();
params.mode = self.detect_brotli_mode(extension);
if let Ok(len) = usize::try_from(file_info.len()) {
params.size_hint = len;
}
let path_buf = path.to_path_buf();
let tempdir = self.tempdir.clone();
let brotli_file = match tokio::task::spawn_blocking(move || {
Self::compress_internal(&tempdir, &path_buf, params)
})
.await
.map_err(|e| {
SerdirError::CompressionError("Join error".to_string(), std::io::Error::other(e))
})? {
Ok(v) => v,
Err(e) if e.kind() == ErrorKind::StorageFull => {
self.prune_cache();
return Self::wrap_orig(path, extension);
}
Err(e) => {
let msg = format!("Brotli compression failed for {}", path.display());
return Err(SerdirError::CompressionError(msg, e));
}
};
let brotli_metadata = brotli_file.metadata()?;
let pseudo_hash = file_info.get_hash().swap_bytes();
let brotli_file_info = crate::FileInfo {
path_hash: pseudo_hash,
len: brotli_metadata.len(),
mtime: brotli_metadata.modified()?,
};
let matched = MatchedFile {
file: Arc::new(brotli_file),
file_info: brotli_file_info,
content_encoding: ContentEncoding::Brotli,
extension: extension.to_string(),
};
self.cache.insert(file_info, matched.clone());
Ok(matched)
}
/// Drops roughly half of the cache, keeping the entries whose cached files
/// are smallest.
fn prune_cache(&self) {
    let mut by_size = self.cache.entries();
    by_size.sort_unstable_by_key(|entry| entry.1.file_info.len());
    // Integer division: with an odd count the larger "half" is evicted.
    let keep = by_size.len() / 2;
    by_size.truncate(keep);
    let survivors: HashSet<CacheKey> = by_size.into_iter().map(|(key, _)| key).collect();
    self.cache.retain(|key, _| survivors.contains(key));
}
/// Opens the file at `path` and wraps it, unmodified, as an
/// identity-encoded `MatchedFile`.
///
/// # Errors
///
/// Fails if the file cannot be opened or its `FileInfo` cannot be read.
fn wrap_orig(path: &Path, extension: &str) -> Result<MatchedFile, SerdirError> {
    let handle = File::open(path)?;
    let file_info = crate::FileInfo::open_file(path, &handle)?;
    Ok(MatchedFile {
        file: Arc::new(handle),
        file_info,
        content_encoding: ContentEncoding::Identity,
        extension: extension.to_owned(),
    })
}
/// Compresses the file at `path` into a fresh anonymous temporary file
/// created inside `tempdir`, and returns that temporary file.
///
/// This does blocking I/O; `get` calls it via `spawn_blocking`.
///
/// # Errors
///
/// Fails if the temporary file cannot be created, the source cannot be
/// opened, or copying/flushing the compressed stream fails.
fn compress_internal(
    tempdir: &Path,
    path: &Path,
    params: BrotliEncoderParams,
) -> Result<File, crate::IOError> {
    let brotli_file: File = tempfile::tempfile_in(tempdir)?;
    let mut compressor = brotli::CompressorWriter::with_params(brotli_file, BUF_SIZE, &params);
    let mut file = File::open(path)?;
    std::io::copy(&mut file, &mut compressor)?;
    compressor.flush()?;
    // NOTE(review): `into_inner` is relied on to finish the brotli stream; it
    // returns the writer directly, so an error at that point would not be
    // surfaced here — confirm against the brotli crate.
    Ok(compressor.into_inner())
}
}
impl Debug for BrotliCache {
    /// Formats the cache configuration (capacity, quality, temp directory and
    /// extension list); cached entries themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BrotliCache");
        out.field("cache_size", &self.cache.capacity());
        out.field("compression_level", &self.params.quality);
        out.field("tempdir", &self.tempdir);
        out.field("supported_extensions", &self.supported_extensions);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::{collections::HashSet, io::Read};
// Fixture written out as `test.jpg`, an extension outside the default
// extension list, to exercise the "serve uncompressed" path.
static CAT_PHOTO_BYTES: &[u8] = include_bytes!("test-resources/cat.jpg");
// Plain-text fixture used to compare output across compression levels.
static BOOK_BYTES: &[u8] = include_bytes!("test-resources/wonderland.txt");
/// Reads all of `f` from its start, brotli-decompressing the contents when
/// `decompress` is set.
fn read_bytes(f: &File, decompress: bool) -> Vec<u8> {
    use std::io::{Seek, SeekFrom};
    // `&File` implements `Read`/`Seek`, so a mutable binding of the shared
    // reference is all that is needed.
    let mut reader = f;
    reader
        .seek(SeekFrom::Start(0))
        .expect("Failed to seek to beginning");
    let mut contents = Vec::new();
    if decompress {
        brotli::Decompressor::new(reader, 4096)
            .read_to_end(&mut contents)
            .expect("Failed to decompress file");
    } else {
        reader
            .read_to_end(&mut contents)
            .expect("Failed to read file");
    }
    contents
}
// Default settings yield quality L5 and a capacity of 128.
#[test]
fn test_brotli_cache_initialization_defaults() {
    let cache = BrotliCache::from(crate::compression::CachedCompression::default());
    let expected_quality = i32::from(crate::compression::BrotliLevel::L5);
    assert_eq!(cache.params.quality, expected_quality);
    assert_eq!(cache.cache.capacity(), 128);
}
// Custom capacity, level and extension list are carried into the cache.
#[test]
fn test_brotli_cache_initialization_custom() {
    let extensions = HashSet::from(["html"]);
    let cache = BrotliCache::from(
        crate::compression::CachedCompression::new()
            .max_size(64)
            .compression_level(crate::compression::BrotliLevel::L5)
            .supported_extensions(Some(extensions.clone())),
    );
    assert_eq!(cache.cache.capacity(), 64);
    assert_eq!(cache.params.quality, 5);
    assert_eq!(cache.supported_extensions, extensions);
}
// Level L1 maps to encoder quality 1.
#[test]
fn test_brotli_cache_build() {
    let settings = crate::compression::CachedCompression::new()
        .max_size(16)
        .compression_level(crate::compression::BrotliLevel::L1);
    let cache = BrotliCache::from(settings);
    assert_eq!(cache.params.quality, 1);
}
// An HTML file is served brotli-encoded, round-trips through decompression,
// and a second lookup returns an entry with the same mtime.
#[tokio::test]
async fn test_simple_compression() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("test.html");
    let content = "<html><body><h1>Hello World</h1></body></html>";
    std::fs::write(&path, content).unwrap();

    let settings = crate::compression::CachedCompression::new()
        .compression_level(crate::compression::BrotliLevel::L1);
    let cache = BrotliCache::from(settings);

    let first = cache
        .get(&path)
        .await
        .expect("Failed to get file from cache");
    assert!(matches!(first.content_encoding, ContentEncoding::Brotli));

    let orig_info = crate::FileInfo::for_path(&path).unwrap();
    assert_ne!(first.file_info.get_hash(), orig_info.get_hash());

    // The compressed file is created right after the original was written.
    let mtime_gap = first
        .file_info
        .mtime()
        .duration_since(orig_info.mtime())
        .unwrap();
    assert!(mtime_gap.as_nanos() < 10_000_000);

    assert_eq!(read_bytes(&first.file, true), content.as_bytes());

    let second = cache
        .get(&path)
        .await
        .expect("Failed to get file from cache second time");
    assert_eq!(second.file_info.mtime(), first.file_info.mtime());
}
#[tokio::test]
async fn test_compression_levels() {
use std::io::Write;
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("wonderland.txt");
{
let mut f = File::create(&path).unwrap();
f.write_all(BOOK_BYTES).unwrap();
}
let cache0 = BrotliCache::from(
crate::compression::CachedCompression::new()
.compression_level(crate::compression::BrotliLevel::L0),
);
let cache5 = BrotliCache::from(
crate::compression::CachedCompression::new()
.compression_level(crate::compression::BrotliLevel::L5),
);
let matched0 = cache0
.get(&path)
.await
.expect("Failed to get file from cache0");
let matched5 = cache5
.get(&path)
.await
.expect("Failed to get file from cache5");
assert!(matches!(matched0.content_encoding, ContentEncoding::Brotli));
assert!(matches!(matched5.content_encoding, ContentEncoding::Brotli));
let decompressed0 = read_bytes(&matched0.file, true);
assert_eq!(decompressed0, BOOK_BYTES);
let decompressed5 = read_bytes(&matched5.file, true);
assert_eq!(decompressed5, BOOK_BYTES);
let size0 = matched0.file.metadata().unwrap().len();
let size5 = matched5.file.metadata().unwrap().len();
assert!(size0 > size5);
assert_eq!(83898, size0);
assert_eq!(59317, size5);
}
// A `.jpg` file is outside the default extension list and is served as-is.
#[tokio::test]
async fn test_skip_compression() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("test.jpg");
    std::fs::write(&path, CAT_PHOTO_BYTES).unwrap();

    let cache = BrotliCache::from(crate::compression::CachedCompression::default());
    let served = cache
        .get(&path)
        .await
        .expect("Failed to get file from cache");
    assert!(matches!(served.content_encoding, ContentEncoding::Identity));

    let on_disk = crate::FileInfo::for_path(&path).unwrap();
    assert_eq!(served.file_info.len(), on_disk.len());
    assert_eq!(served.file_info.mtime(), on_disk.mtime());
    assert_eq!(read_bytes(&served.file, false), CAT_PHOTO_BYTES);
}
// With a 10-byte limit, a longer text file is served uncompressed while a
// 5-byte one is still compressed.
#[tokio::test]
async fn test_max_file_size() {
    let dir = tempfile::tempdir().unwrap();

    let big_path = dir.path().join("too_big.txt");
    std::fs::write(
        &big_path,
        "This is a test file that is larger than 10 bytes.",
    )
    .unwrap();

    let cache = BrotliCache::from(crate::compression::CachedCompression::new().max_file_size(10));
    let big = cache
        .get(&big_path)
        .await
        .expect("Failed to get file from cache");
    assert!(matches!(big.content_encoding, ContentEncoding::Identity));

    let small_path = dir.path().join("small.txt");
    std::fs::write(&small_path, "small").unwrap();
    let small = cache
        .get(&small_path)
        .await
        .expect("Failed to get small file from cache");
    assert!(matches!(small.content_encoding, ContentEncoding::Brotli));
}
// After pruning, half of the entries remain and none is larger than the
// largest entry of the pre-prune smaller half.
#[tokio::test]
async fn test_prune_cache_keeps_smallest_half() {
    let dir = tempfile::tempdir().unwrap();
    let cache = BrotliCache::from(crate::compression::CachedCompression::new().max_size(16));

    let fixtures = [
        ("a.txt", 64usize),
        ("b.txt", 256usize),
        ("c.txt", 1024usize),
        ("d.txt", 4096usize),
    ];
    for (name, size) in fixtures {
        let path = dir.path().join(name);
        std::fs::write(&path, vec![b'x'; size]).unwrap();
        let _ = cache.get(&path).await.unwrap();
    }

    let mut by_size = cache.cache.entries();
    by_size.sort_unstable_by_key(|(_, matched)| matched.file_info.len());
    let keep = by_size.len() / 2;
    let expected_sizes: Vec<u64> = by_size[..keep]
        .iter()
        .map(|(_, matched)| matched.file_info.len())
        .collect();
    let max_expected_size = *expected_sizes.last().unwrap();

    cache.prune_cache();

    let remaining = cache.cache.entries();
    assert_eq!(remaining.len(), expected_sizes.len());
    assert!(remaining
        .iter()
        .all(|(_, matched)| matched.file_info.len() <= max_expected_size));
}
}