use axum::{
body::Body,
extract::Request,
http::{header, Method, StatusCode},
response::Response,
};
use bytes::Bytes;
use rust_silos::SiloSet;
use std::{
collections::HashMap,
convert::Infallible,
future::Future,
io::Read,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::Service;
use parking_lot::RwLock;
use blake3::Hasher as Blake3;
use mime_guess::MimeGuess;
use percent_encoding::percent_decode_str;
/// Tower service that answers requests with files looked up in a [`SiloSet`].
///
/// Built with [`AssetServe::new`] and tuned through the builder-style
/// `precompressed` and `with_etag` methods.
pub struct AssetServe {
/// Shared handle to the asset store every lookup goes through.
silos: Arc<SiloSet>,
/// Folder prefix, as produced by `normalize_prefix`, that is put in front of
/// the cleaned request path before the lookup.
prefix: Arc<str>,
/// When set, `.br` / `.gz` sibling files are tried according to the request's
/// `Accept-Encoding`, and responses carry `Vary: Accept-Encoding`.
precompressed: bool,
/// When set, responses carry an `ETag` and `If-None-Match` is honoured.
etag: bool,
/// Memoised ETag strings, shared between the service and the futures it spawns.
etag_cache: Arc<RwLock<HashMap<String, String>>>,
}
impl AssetServe {
    /// Creates a service that serves files from `silos`, looking them up
    /// underneath `folder` (normalised with `normalize_prefix`).
    ///
    /// Precompressed variants and ETag handling both start out disabled and
    /// the ETag cache starts out empty.
    pub fn new(silos: SiloSet, folder: &str) -> Self {
        let prefix: Arc<str> = Arc::from(normalize_prefix(folder));
        let etag_cache = Arc::new(RwLock::new(HashMap::new()));
        Self {
            silos: Arc::new(silos),
            prefix,
            precompressed: false,
            etag: false,
            etag_cache,
        }
    }

    /// Turns the lookup of `.br` / `.gz` variants on or off and returns the
    /// updated service.
    pub fn precompressed(self, enabled: bool) -> Self {
        Self {
            precompressed: enabled,
            ..self
        }
    }

    /// Turns `ETag` / `If-None-Match` handling on or off and returns the
    /// updated service.
    pub fn with_etag(self, enabled: bool) -> Self {
        Self {
            etag: enabled,
            ..self
        }
    }
}
impl Service<Request> for AssetServe {
    type Response = Response;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send>>;

    /// The service keeps no per-request state, so it is always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Copies everything the response needs out of `req` and the service, then
    /// hands the work to `serve_file_impl` inside a boxed future.
    ///
    /// Header values that are not valid visible ASCII are treated as absent.
    /// `If-None-Match` is only read when ETag support is enabled.
    fn call(&mut self, req: Request) -> Self::Future {
        let use_etag = self.etag;
        let precompressed = self.precompressed;

        let headers = req.headers();
        let accept_encoding = headers
            .get(header::ACCEPT_ENCODING)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        let if_none_match = use_etag
            .then(|| headers.get(header::IF_NONE_MATCH))
            .flatten()
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);

        let method = req.method().clone();
        let raw_path = req.uri().path().to_owned();

        // Cheap reference-count bumps so the future owns everything it touches.
        let silos = Arc::clone(&self.silos);
        let prefix = Arc::clone(&self.prefix);
        let cache = Arc::clone(&self.etag_cache);

        Box::pin(async move {
            let response = serve_file_impl(
                &silos,
                &prefix,
                &method,
                &raw_path,
                precompressed,
                accept_encoding.as_deref(),
                if_none_match.as_deref(),
                use_etag,
                &cache,
            )
            .await;
            Ok(response)
        })
    }
}
/// Builds the complete response for one asset request.
///
/// Steps, in order:
/// 1. Anything other than `GET` / `HEAD` gets `405` with an `Allow` header.
/// 2. `raw_path` is sanitised with `clean_rel_path` and joined onto `prefix`;
///    a rejected path yields `404`.
/// 3. `read_best_variant` loads the file (or a `.br` / `.gz` sibling when
///    `precompressed` is set); a missing file yields `404`.
/// 4. When `use_etag` is set, an ETag for the bytes actually served is obtained
///    and compared against `if_none_match`; a match yields `304`.
/// 5. Otherwise a `200` is produced; `HEAD` gets the same headers with an empty
///    body.
///
/// `Content-Type` and `Cache-Control` are derived from the *logical* path, not
/// from the variant that was read: `app.js.br` must still be announced as
/// JavaScript, and `index.html.gz` must still get the HTML caching policy.
async fn serve_file_impl(
    silos: &SiloSet,
    prefix: &str,
    method: &Method,
    raw_path: &str,
    precompressed: bool,
    accept_encoding: Option<&str>,
    if_none_match: Option<&str>,
    use_etag: bool,
    cache: &RwLock<HashMap<String, String>>,
) -> Response {
    if *method != Method::GET && *method != Method::HEAD {
        // A 405 has to tell the client which methods are supported.
        return Response::builder()
            .status(StatusCode::METHOD_NOT_ALLOWED)
            .header(header::ALLOW, "GET, HEAD")
            .body(Body::empty())
            .unwrap();
    }

    let Some(clean_rel) = clean_rel_path(raw_path) else {
        return not_found();
    };
    let logical_path = join_prefix(prefix, &clean_rel);

    let Some((served_path, bytes, content_encoding)) =
        read_best_variant(silos, &logical_path, precompressed, accept_encoding).await
    else {
        return not_found();
    };

    // Representation metadata describes the decoded resource, so it is taken
    // from the logical path rather than from the `.br` / `.gz` variant path.
    let cache_control = cache_control_for(&logical_path);

    // The ETag, in contrast, identifies the exact bytes on the wire, so it is
    // keyed on the variant that was actually read.
    let etag_val = if use_etag {
        Some(get_or_compute_etag(cache, &served_path, &bytes, silos).await)
    } else {
        None
    };

    if let (Some(etag), Some(client_tags)) = (etag_val.as_deref(), if_none_match) {
        if etag_matches_if_none_match(client_tags, etag) {
            let mut response = not_modified(etag, precompressed);
            // A 304 should repeat the caching headers the 200 would have sent.
            response.headers_mut().insert(
                header::CACHE_CONTROL,
                header::HeaderValue::from_static(cache_control),
            );
            return response;
        }
    }

    let mut builder = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, guess_mime(&logical_path));
    if let Some(enc) = content_encoding {
        builder = builder.header(header::CONTENT_ENCODING, enc);
    }
    if precompressed {
        // The body depends on Accept-Encoding whenever variants are enabled,
        // even when this particular response is the uncompressed file.
        builder = builder.header(header::VARY, "Accept-Encoding");
    }
    builder = builder.header(header::CACHE_CONTROL, cache_control);
    if let Some(etag) = etag_val.as_deref() {
        builder = builder.header(header::ETAG, etag);
    }
    builder = builder.header(header::CONTENT_LENGTH, bytes.len().to_string());

    if *method == Method::HEAD {
        // HEAD reports the headers of the GET response without its body.
        return builder.body(Body::empty()).unwrap();
    }
    builder.body(Body::from(bytes)).unwrap()
}

/// Returns `true` when an `If-None-Match` header value selects `current_etag`.
///
/// The header may be `*` or a comma-separated list of entity tags, and the
/// comparison is weak: a `W/` prefix on either side is ignored. Splitting on
/// commas does not honour commas inside quoted tags; such a tag can only fail
/// to match, which merely costs a full response.
fn etag_matches_if_none_match(header_value: &str, current_etag: &str) -> bool {
    let current = strip_weak_prefix(current_etag);
    header_value
        .split(',')
        .map(str::trim)
        .any(|candidate| candidate == "*" || strip_weak_prefix(candidate) == current)
}

/// Drops the weak-validator marker `W/` from an entity tag, if present.
fn strip_weak_prefix(tag: &str) -> &str {
    tag.strip_prefix("W/").unwrap_or(tag)
}
/// Loads the bytes to serve for `logical_path`.
///
/// With `precompressed` set, a `.br` sibling is tried first and a `.gz`
/// sibling second, each only when the `Accept-Encoding` header allows that
/// coding; the plain file is the fallback. Without it, only the plain file is
/// read.
///
/// Returns the path that was read, its contents, and the `Content-Encoding`
/// token to announce (`None` for the plain file). Returns `None` when no
/// candidate could be read.
async fn read_best_variant(
    silos: &SiloSet,
    logical_path: &str,
    precompressed: bool,
    accept_encoding: Option<&str>,
) -> Option<(String, Bytes, Option<&'static str>)> {
    if precompressed {
        let accepted = AcceptEncoding::parse(accept_encoding);
        // (client accepts it, file extension, Content-Encoding token), in
        // order of preference.
        let variants = [
            (accepted.allows("br"), "br", "br"),
            (accepted.allows("gzip") || accepted.allows("gz"), "gz", "gzip"),
        ];
        for (acceptable, extension, coding) in variants {
            if !acceptable {
                continue;
            }
            let candidate = format!("{logical_path}.{extension}");
            if let Some(data) = try_read_file(silos, &candidate).await {
                return Some((candidate, data, Some(coding)));
            }
        }
    }

    try_read_file(silos, logical_path)
        .await
        .map(|data| (logical_path.to_string(), data, None))
}
/// Reads the whole file stored under `path` in `silos`.
///
/// Files that are not embedded are read through `tokio::fs::read`; embedded
/// files are drained from their reader. Any failure — unknown path, reader
/// error, I/O error — is reported as `None`.
async fn try_read_file(silos: &SiloSet, path: &str) -> Option<Bytes> {
    let file = silos.get_file(path)?;

    if !file.is_embedded() {
        return tokio::fs::read(file.path()).await.ok().map(Bytes::from);
    }

    let mut source = file.reader().ok()?;
    let mut contents = Vec::new();
    source.read_to_end(&mut contents).ok()?;
    Some(Bytes::from(contents))
}
/// Turns a request path into a safe, slash-separated relative path.
///
/// Leading slashes are removed and percent-escapes are decoded (the result
/// must be valid UTF-8). Empty segments are dropped. The path is rejected with
/// `None` when it contains a backslash, a `.` or `..` segment, or a NUL
/// character.
fn clean_rel_path(raw_path: &str) -> Option<String> {
    let decoded = percent_decode_str(raw_path.trim_start_matches('/'))
        .decode_utf8()
        .ok()?;
    if decoded.contains('\\') {
        return None;
    }

    let segments = decoded
        .split('/')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let forbidden = matches!(segment, "." | "..") || segment.contains('\0');
            (!forbidden).then_some(segment)
        })
        .collect::<Option<Vec<&str>>>()?;

    Some(segments.join("/"))
}
/// Normalises a folder name into the prefix form used for lookups.
///
/// Leading and trailing slashes are removed and a single trailing slash is
/// appended, so `"/assets/"` and `"assets"` both become `"assets/"`.
///
/// A folder that is empty *or consists only of slashes* yields the empty
/// prefix. The emptiness check therefore runs after trimming; checking before
/// it would turn `"/"` into the prefix `"/"` and make every lookup path start
/// with a slash.
fn normalize_prefix(folder: &str) -> String {
    let trimmed = folder.trim_matches('/');
    if trimmed.is_empty() {
        return String::new();
    }
    let mut prefix = String::with_capacity(trimmed.len() + 1);
    prefix.push_str(trimmed);
    prefix.push('/');
    prefix
}
/// Combines a lookup prefix with a cleaned relative path.
///
/// * empty prefix: the relative path as-is;
/// * empty relative path: the prefix without its trailing slashes;
/// * otherwise: plain concatenation (the prefix is expected to already end
///   with `/`).
fn join_prefix(prefix: &str, rel: &str) -> String {
    match (prefix.is_empty(), rel.is_empty()) {
        (true, _) => rel.to_owned(),
        (false, true) => prefix.trim_end_matches('/').to_owned(),
        (false, false) => [prefix, rel].concat(),
    }
}
/// Picks a media type for `path` from its file extension.
///
/// Falls back to `application/octet-stream` when the extension is unknown and
/// returns only the `type/subtype` essence.
fn guess_mime(path: &str) -> String {
    let candidates: MimeGuess = mime_guess::from_path(path);
    let mime = candidates.first_or_octet_stream();
    mime.essence_str().to_owned()
}
/// Chooses the `Cache-Control` value for a path.
///
/// Paths ending in `.html` must be revalidated on every use; everything else
/// is treated as immutable for one year.
fn cache_control_for(path: &str) -> &'static str {
    match path.strip_suffix(".html") {
        Some(_) => "no-cache",
        None => "public, max-age=31536000, immutable",
    }
}
async fn get_or_compute_etag(
cache: &RwLock<HashMap<String, String>>,
path: &str,
bytes: &Bytes,
silos: &SiloSet,
) -> String {
let file = silos.get_file(path);
let is_embedded = file.as_ref().map(|f| f.is_embedded()).unwrap_or(false);
let cache_key = if is_embedded {
path.to_string()
} else {
match file.as_ref().and_then(|f| std::fs::metadata(f.path()).ok()) {
Some(meta) => {
let mtime = meta.modified()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| d.as_secs())
.unwrap_or(0);
format!("{path}:{mtime}:{}", meta.len())
}
None => return strong_etag(bytes),
}
};
{
let cache_read = cache.read();
if let Some(etag) = cache_read.get(&cache_key) {
return etag.clone();
}
}
let etag = strong_etag(bytes);
cache.write().insert(cache_key, etag.clone());
etag
}
/// Derives a strong entity tag from content: the BLAKE3 digest of `bytes`,
/// hex-encoded and wrapped in double quotes.
fn strong_etag(bytes: &Bytes) -> String {
    let hex = Blake3::new().update(bytes).finalize().to_hex();
    format!("\"{hex}\"")
}
/// Builds an empty `304 Not Modified` response carrying `etag`.
///
/// `Vary: Accept-Encoding` is added when precompressed variants are enabled.
fn not_modified(etag: &str, precompressed: bool) -> Response {
    let base = Response::builder()
        .status(StatusCode::NOT_MODIFIED)
        .header(header::ETAG, etag);
    let builder = if precompressed {
        base.header(header::VARY, "Accept-Encoding")
    } else {
        base
    };
    builder.body(Body::empty()).unwrap()
}
/// Builds an empty `404 Not Found` response.
fn not_found() -> Response {
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::NOT_FOUND;
    response
}
/// Client preferences parsed from an `Accept-Encoding` header, restricted to
/// the codings this server can provide.
///
/// Each field holds the q-value given for that coding, or a negative sentinel
/// when the header did not mention it.
#[derive(Debug, Clone)]
struct AcceptEncoding {
    /// Weight for `br`.
    br_q: f32,
    /// Weight for `gzip` (also set by the aliases `x-gzip` and `gz`).
    gzip_q: f32,
    /// Weight for the `*` wildcard.
    star_q: f32,
}

impl AcceptEncoding {
    /// Sentinel meaning "the header did not list this coding".
    const UNSPECIFIED: f32 = -1.0;

    /// Parses an `Accept-Encoding` header value.
    ///
    /// * A missing header yields no listed codings, so nothing is allowed.
    /// * Coding names and the `q` parameter name are matched
    ///   case-insensitively.
    /// * A q-value that does not parse, or parses to NaN, is ignored and the
    ///   default weight `1.0` applies. Other values are clamped into
    ///   `0.0..=1.0`, so a stored weight can never collide with the sentinel.
    /// * When a coding is listed more than once, the last occurrence wins.
    fn parse(h: Option<&str>) -> Self {
        let mut ae = AcceptEncoding {
            br_q: Self::UNSPECIFIED,
            gzip_q: Self::UNSPECIFIED,
            star_q: Self::UNSPECIFIED,
        };
        let Some(s) = h else { return ae };

        for part in s.split(',') {
            let mut pieces = part.split(';').map(str::trim);
            let enc = pieces.next().unwrap_or("");
            if enc.is_empty() {
                continue;
            }

            let mut q = 1.0f32;
            for param in pieces {
                let Some((name, value)) = param.split_once('=') else {
                    continue;
                };
                if !name.trim().eq_ignore_ascii_case("q") {
                    continue;
                }
                if let Ok(val) = value.trim().parse::<f32>() {
                    if !val.is_nan() {
                        q = val.clamp(0.0, 1.0);
                    }
                }
            }

            if enc.eq_ignore_ascii_case("br") {
                ae.br_q = q;
            } else if Self::is_gzip(enc) {
                ae.gzip_q = q;
            } else if enc == "*" {
                ae.star_q = q;
            }
        }
        ae
    }

    /// Reports whether the client accepts `enc`.
    ///
    /// An explicit entry for the coding decides on its own (`q > 0`). Without
    /// one, the `*` wildcard decides. With neither, the coding is refused.
    fn allows(&self, enc: &str) -> bool {
        let explicit = if enc.eq_ignore_ascii_case("br") {
            self.br_q
        } else if Self::is_gzip(enc) {
            self.gzip_q
        } else {
            Self::UNSPECIFIED
        };
        if explicit >= 0.0 {
            return explicit > 0.0;
        }
        self.star_q > 0.0
    }

    /// Recognises `gzip` and its aliases `x-gzip` and `gz`, ignoring case.
    fn is_gzip(enc: &str) -> bool {
        ["gzip", "x-gzip", "gz"]
            .iter()
            .any(|name| enc.eq_ignore_ascii_case(name))
    }
}