use std::{
collections::BTreeMap,
fmt::{self, Display},
fs::{self, File},
io::Read,
};
use std::{collections::VecDeque, path::Path};
use std::{fs::DirEntry, path::PathBuf};
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, HeaderValue},
middleware::Next,
response::Response,
};
use chrono::{DateTime, Duration, TimeDelta, Utc};
use regex::Regex;
use reqwest::{
StatusCode,
header::{
CACHE_CONTROL, ETAG, EXPIRES, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_RANGE,
IF_UNMODIFIED_SINCE, PRAGMA,
},
};
use tracing::{error, instrument, warn};
/// Maps asset files to their content-hashed ("cache-busted") counterparts.
///
/// The mapping starts empty (see `new`) and is filled by `gen_cache`, which
/// renames the files found under `asset_directory` on disk.
#[derive(Debug, Clone)]
pub struct CacheBuster {
/// Root directory scanned for assets; also the string prefix `get_file`
/// requires before it consults the cache.
asset_directory: String,
/// Original file path -> renamed (hashed) file path, both stored as strings.
cache: BTreeMap<String, String>,
}
impl CacheBuster {
/// Creates a `CacheBuster` for `asset_directory` with an empty mapping.
///
/// No file-system access happens here; call `gen_cache` to populate the
/// mapping.
#[must_use]
#[instrument(skip_all)]
pub fn new(asset_directory: &str) -> Self {
    let asset_directory: String = asset_directory.to_owned();
    let cache: BTreeMap<String, String> = BTreeMap::new();
    Self {
        asset_directory,
        cache,
    }
}
/// Rebuilds the mapping by scanning the asset directory.
///
/// Delegates to the free function `gen_cache`, which renames the files it
/// finds on disk, and replaces any previously stored mapping with the result.
#[instrument(skip_all)]
pub fn gen_cache(&mut self) {
    let root: &Path = Path::new(&self.asset_directory);
    self.cache = gen_cache(root);
}
/// Looks up the cache-busted path for `original_asset_file_path`.
///
/// The input is returned unchanged (as an owned `String`) in two cases:
/// - it does not start with the asset directory (string prefix test), which
///   is logged as a warning;
/// - it has no entry in the mapping, which is logged as an error.
#[must_use]
#[instrument(skip_all)]
pub fn get_file(&self, original_asset_file_path: &str) -> String {
    let inside_asset_directory: bool =
        original_asset_file_path.starts_with(self.asset_directory.as_str());
    if !inside_asset_directory {
        warn!(
            "CacheBuster: File path does not start with asset directory: '{original_asset_file_path:?}'. Returning original path: '{original_asset_file_path:?}'."
        );
        return original_asset_file_path.to_string();
    }
    match self.cache.get(original_asset_file_path) {
        Some(busted_path) => busted_path.clone(),
        None => {
            error!(
                "CacheBuster: File not found in cache: '{original_asset_file_path:?}'. Returning original path."
            );
            original_asset_file_path.to_string()
        }
    }
}
/// Returns an owned copy of the whole original-path -> hashed-path mapping.
///
/// The copy is independent of `self`; later changes to either side are not
/// reflected in the other.
#[must_use]
#[instrument(skip_all)]
pub fn get_cache(&self) -> BTreeMap<String, String> {
    let snapshot: BTreeMap<String, String> = BTreeMap::clone(&self.cache);
    snapshot
}
/// Writes the mapping as pretty-printed JSON to `<output_dir>/cache-buster.json`.
///
/// An existing file at that path is truncated and overwritten.
///
/// # Panics
///
/// Panics if the file cannot be created, or if serializing or flushing the
/// JSON fails. The panic message includes the underlying error.
#[instrument(skip_all)]
pub fn print_to_file(&self, output_dir: &str) {
    let output_path: PathBuf = Path::new(output_dir).join("cache-buster.json");
    let file: File = File::create(&output_path).unwrap_or_else(|err| {
        panic!("Failed to create file: {}: {err}", output_path.display())
    });
    // The serializer writes straight into the writer it is given, so buffer
    // the output instead of handing it a raw `File`.
    let mut writer: std::io::BufWriter<File> = std::io::BufWriter::new(file);
    serde_json::to_writer_pretty(&mut writer, &self.cache).unwrap_or_else(|err| {
        panic!(
            "Failed to write JSON to file: {}: {err}",
            output_path.display()
        )
    });
    // Flush explicitly: dropping a `BufWriter` would swallow a failed final write.
    std::io::Write::flush(&mut writer).unwrap_or_else(|err| {
        panic!(
            "Failed to write JSON to file: {}: {err}",
            output_path.display()
        )
    });
}
#[instrument(skip_all)]
pub fn update_source_map_references(&self) {
let source_map_regex: Regex = Regex::new(r"//# sourceMappingURL=(.+\.js\.map)")
.unwrap_or_else(|_| panic!("Failed to compile sourceMappingURL regex"));
for (original_path, hashed_path) in &self.cache {
if !std::path::Path::new(original_path)
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("js"))
{
continue;
}
let original_map_path: String = format!("{original_path}.map");
let hashed_map_path: &String = match self.cache.get(&original_map_path) {
Some(path) => path,
None => continue,
};
let mut content: String = fs::read_to_string(hashed_path)
.unwrap_or_else(|_| panic!("Failed to read file: {hashed_path}"));
let hashed_map_filename: &str = Path::new(hashed_map_path)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or_else(|| panic!("Invalid hashed map path"));
if source_map_regex.is_match(&content) {
content = source_map_regex
.replace(
&content,
format!("//# sourceMappingURL={hashed_map_filename}"),
)
.into_owned();
fs::write(hashed_path, content)
.unwrap_or_else(|_| panic!("Failed to write file: {hashed_path}"));
}
}
}
#[instrument(skip_all)]
pub async fn never_cache_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
let mut response: Response<Body> = next.run(req).await;
remove_etag_headers(response.headers_mut());
response.headers_mut().insert(
EXPIRES,
HeaderValue::from_static("Thu, 01 Jan 1970 00:00:00 GMT"),
);
response.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_static("no-cache, no-store, must-revalidate, private, max-age=0"),
);
response
.headers_mut()
.insert(PRAGMA, HeaderValue::from_static("no-cache"));
Ok(response)
}
/// Middleware that marks every response as cacheable for one year.
///
/// Logs the request path, runs the inner service, strips `ETag` and the
/// conditional headers from the response, then sets `Expires` to now + 365
/// days and a long-lived public `Cache-Control`. Always returns `Ok`.
#[instrument(skip_all)]
pub async fn forever_cache_middleware(
    req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    warn!(
        "CacheBuster: Forever-cacheing resource: '{}'",
        req.uri().path()
    );
    let mut response: Response<Body> = next.run(req).await;
    remove_etag_headers(response.headers_mut());
    let one_year: TimeDelta = Duration::days(365);
    let expires: DateTime<Utc> = Utc::now() + one_year;
    // `Expires` must be an HTTP-date in IMF-fixdate form ("..., 01 Jan 1970
    // 00:00:00 GMT"). RFC 2822 output ends in "+0000" and does not pad the
    // day, which recipients may treat as an invalid (already expired) date.
    let http_date: String = expires.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
    response.headers_mut().insert(
        EXPIRES,
        HeaderValue::from_str(&http_date)
            .expect("an IMF-fixdate string is always a valid header value"),
    );
    response.headers_mut().insert(
        CACHE_CONTROL,
        HeaderValue::from_static("public, max-age=31536000, must-revalidate, immutable"),
    );
    Ok(response)
}
}
/// Human-readable dump: a header line naming the asset directory, followed
/// by one tab-indented `'original' -> 'hashed'` line per mapping entry.
impl Display for CacheBuster {
    #[instrument(skip_all)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "CacheBuster (asset directory: '{}'):",
            self.asset_directory
        )?;
        // A `BTreeMap` already iterates in ascending key order, so no separate
        // key collection, sort, or re-lookup is needed.
        for (original, hashed) in &self.cache {
            write!(f, "\n\t'{original}' -> '{hashed}'")?;
        }
        Ok(())
    }
}
/// Renames every file under `root` to a content-hashed name and returns the
/// original-path -> new-path mapping.
///
/// Directories are visited breadth-first starting at `root`. Each
/// non-directory entry is renamed on disk to the path produced by
/// `generate_cache_busted_path`. Both paths are stored as lossily converted
/// strings.
///
/// # Panics
///
/// Panics if a directory or directory entry cannot be read, or if a rename
/// fails. The panic message includes the underlying error.
#[instrument(skip_all)]
fn gen_cache(root: &Path) -> BTreeMap<String, String> {
    let mut cache: BTreeMap<String, String> = BTreeMap::new();
    let mut dirs_to_visit: VecDeque<PathBuf> = VecDeque::new();
    dirs_to_visit.push_back(root.to_path_buf());
    while let Some(dir_path) = dirs_to_visit.pop_front() {
        // Snapshot the listing before renaming anything. Renaming files while
        // the directory iterator is still live may make the renamed file show
        // up again and get hashed a second time.
        let entries: Vec<PathBuf> = fs::read_dir(&dir_path)
            .unwrap_or_else(|err| {
                panic!("Failed to read directory: {}: {err}", dir_path.display())
            })
            .map(|entry| {
                let entry: DirEntry = entry.unwrap_or_else(|err| {
                    panic!(
                        "Failed to read directory entry: {} -> {err}",
                        dir_path.display()
                    )
                });
                entry.path()
            })
            .collect();
        for path in entries {
            if path.is_dir() {
                dirs_to_visit.push_back(path);
                continue;
            }
            let original_file_path: String = path.to_string_lossy().to_string();
            let new_file_path: String = generate_cache_busted_path(&path, root)
                .to_string_lossy()
                .to_string();
            fs::rename(&original_file_path, &new_file_path).unwrap_or_else(|err| {
                panic!("Failed to rename file: {original_file_path} -> {new_file_path}: {err}")
            });
            cache.insert(original_file_path, new_file_path);
        }
    }
    cache
}
/// Computes the content-hashed path for `file_path`.
///
/// The file is read in full and its MD5 digest (lower-case hex) is inserted
/// after the first `.` of the file name (`name.<hash>.rest`), or appended
/// (`name.<hash>`) when the name contains no dot. The result is `root`
/// joined with the file's parent directory relative to `root` and the new
/// file name. A file name that is missing or not valid UTF-8 is treated as
/// the empty string.
///
/// # Panics
///
/// Panics if the file cannot be opened or read.
#[instrument(skip_all)]
fn generate_cache_busted_path(file_path: &Path, root: &Path) -> PathBuf {
    let Ok(mut source) = File::open(file_path) else {
        panic!(
            "Failed to open file: {} -> {}",
            root.display(),
            file_path.display()
        )
    };
    let mut bytes: Vec<u8> = Vec::new();
    if source.read_to_end(&mut bytes).is_err() {
        panic!(
            "Failed to read file: {} -> {}",
            root.display(),
            file_path.display()
        );
    }
    let hash: String = format!("{:x}", md5::compute(bytes));

    // Fall back to the full path when the file does not live under `root`.
    let relative: &Path = match file_path.strip_prefix(root) {
        Ok(stripped) => stripped,
        Err(_) => file_path,
    };
    let directory: &Path = relative.parent().unwrap_or_else(|| Path::new(""));
    let file_name: &str = relative
        .file_name()
        .and_then(|os_name| os_name.to_str())
        .unwrap_or("");

    let busted_name: String = match file_name.split_once('.') {
        Some((name, rest)) => format!("{name}.{hash}.{rest}"),
        None => format!("{file_name}.{hash}"),
    };
    root.join(directory).join(busted_name)
}
/// Removes `ETag` and the conditional headers (`If-Modified-Since`,
/// `If-Match`, `If-None-Match`, `If-Range`, `If-Unmodified-Since`) from
/// `headers`. Headers that are not present are ignored.
#[instrument(skip_all)]
fn remove_etag_headers(headers: &mut HeaderMap) {
    for name in [
        ETAG,
        IF_MODIFIED_SINCE,
        IF_MATCH,
        IF_NONE_MATCH,
        IF_RANGE,
        IF_UNMODIFIED_SINCE,
    ] {
        headers.remove(name);
    }
}