use std::collections::HashMap;
use std::path::PathBuf;
use axum::body::Body as AxumBody;
use axum::extract::Request;
use axum::http::{Response, StatusCode};
use axum::response::IntoResponse;
use tower::ServiceExt as TowerServiceExt;
use tower_http::services::ServeDir;
use crate::AppState;
use crate::registry::MountMode;
/// Returns `true` when `prefix` matches `path` on a whole-segment boundary.
///
/// * The root prefix `"/"` matches every path.
/// * Otherwise `path` must either equal `prefix` or continue with a `/`
///   immediately after it, so `/asset` matches `/asset` and `/asset/a.txt`
///   but not `/assets`.
/// * A single trailing slash on `prefix` is ignored, so `/assets/` behaves
///   exactly like `/assets`. Without this, a prefix ending in `/` could never
///   match a nested path because the byte following the prefix is not `/`.
fn prefix_matches_segment(path: &str, prefix: &str) -> bool {
    if prefix == "/" {
        return true;
    }
    // Normalise `/assets/` to `/assets` so the boundary check below applies.
    let prefix = prefix.strip_suffix('/').unwrap_or(prefix);
    match path.strip_prefix(prefix) {
        // Either an exact match, or the next character starts a new segment.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
/// Resolves a static-file request against the registered mounts and serves it.
///
/// Steps, in order:
///
/// 1. Any request whose path contains a literal `..` segment is answered with
///    a plain `404 Not Found`.
/// 2. The registry's mounts are snapshotted (the read lock is released before
///    any file is served) and tried from the longest mount path to the
///    shortest, skipping mounts whose path does not match on a segment
///    boundary.
/// 3. For each matching mount the mount prefix is stripped and the request is
///    handed to that mount's `ServeDir`. A success response or a
///    `304 Not Modified` is returned as-is.
/// 4. For `MountMode::Spa` mounts, a `404` on a request that
///    `is_spa_qualified` accepts is answered by `serve_spa_index`.
/// 5. Otherwise, if the mount configures an error page for the response
///    status, that page is returned; if not, the next matching mount is tried.
/// 6. When no mount produced a response, the 404 page of the first (longest)
///    matching mount is tried, then the 404 page of the `/` mount, and finally
///    a plain-text `404 Not Found`.
///
/// The incoming request body is discarded; every sub-request is built with an
/// empty body.
pub(crate) async fn dispatch_static(
    state: &AppState,
    req: Request,
    path: &str,
) -> axum::response::Response {
    let relative = path.trim_start_matches('/');
    if relative.split('/').any(|seg| seg == "..") {
        return (StatusCode::NOT_FOUND, "Not Found").into_response();
    }

    let (parts, _body) = req.into_parts();

    // Snapshot the mounts inside a narrow scope so the registry lock is not
    // held across the file-serving awaits below.
    let mut mounts: Vec<_> = {
        let inner = state.registry.inner.read().await;
        inner
            .mounts
            .iter()
            .map(|m| {
                (
                    m.mount_path.clone(),
                    m.mode,
                    m.serve_dir.clone(),
                    m.cache_control.clone(),
                    m.error_pages.clone(),
                )
            })
            .collect()
    };
    // Longest mount path first, so the most specific mount wins. The sort is
    // stable, so equally long paths keep their registry order.
    mounts.sort_by_key(|m| std::cmp::Reverse(m.0.len()));

    // Error pages of the first (most specific) matching mount, used for the
    // final 404 fallback.
    let mut best_match: Option<&HashMap<u16, PathBuf>> = None;

    for (mp, mode, serve_dir, cache_control, error_pages) in &mounts {
        if !prefix_matches_segment(path, mp.as_str()) {
            continue;
        }
        if best_match.is_none() {
            best_match = Some(error_pages);
        }

        let stripped_path = if mp == "/" {
            relative.to_string()
        } else {
            let remainder = path.strip_prefix(mp.as_str()).unwrap_or("");
            remainder.trim_start_matches('/').to_string()
        };

        let sub_req = rebuild_request_with_path(&parts, &stripped_path);
        let resp = serve_via_serve_dir(serve_dir.clone(), sub_req, cache_control).await;
        let status = resp.status();

        // `304 Not Modified` is a valid final answer to a conditional request.
        // Treating it as a failure would turn cache revalidations into 404s.
        if status.is_success() || status == StatusCode::NOT_MODIFIED {
            return resp;
        }

        if *mode == MountMode::Spa && status == StatusCode::NOT_FOUND {
            // The SPA check looks at the original, un-stripped request.
            let spa_req = rebuild_request(&parts);
            if is_spa_qualified(&spa_req) {
                return serve_spa_index(serve_dir.clone(), spa_req, cache_control).await;
            }
        }

        if let Some(error_resp) = try_serve_error_page(error_pages, status.as_u16()).await {
            return error_resp;
        }
    }

    if let Some(error_pages) = best_match
        && let Some(error_resp) =
            try_serve_error_page(error_pages, StatusCode::NOT_FOUND.as_u16()).await
    {
        return error_resp;
    }

    // Fall back to the root mount's 404 page, taken from the same snapshot so
    // the registry does not have to be locked a second time.
    let root_error_pages = mounts.iter().find(|m| m.0 == "/").map(|m| &m.4);
    if let Some(pages) = root_error_pages
        && let Some(error_resp) =
            try_serve_error_page(pages, StatusCode::NOT_FOUND.as_u16()).await
    {
        return error_resp;
    }

    (StatusCode::NOT_FOUND, "Not Found").into_response()
}
/// Builds a fresh, body-less request carrying the method, URI, HTTP version,
/// headers and extensions of `parts`.
///
/// The extensions are copied into the new request's own extension map.
/// Passing the map to `Builder::extension` would instead store the whole map
/// as a single nested extension value, hiding every entry from typed lookups.
///
/// Assigning the parts directly cannot fail, so this function never panics.
fn rebuild_request(parts: &http::request::Parts) -> Request {
    let mut req = http::Request::new(AxumBody::empty());
    *req.method_mut() = parts.method.clone();
    *req.uri_mut() = parts.uri.clone();
    *req.version_mut() = parts.version;
    *req.headers_mut() = parts.headers.clone();
    *req.extensions_mut() = parts.extensions.clone();
    req
}
/// Builds a fresh, body-less request like `parts`, but with its URI replaced
/// by `/{path}`.
///
/// Method, HTTP version, headers and extensions are copied from `parts`. The
/// original query string is not carried over, because the new URI is built
/// from `path` alone. The extensions are copied into the new request's own
/// extension map rather than being stored as one nested extension value.
///
/// # Panics
///
/// Panics if `/{path}` does not parse as a URI.
fn rebuild_request_with_path(parts: &http::request::Parts, path: &str) -> Request {
    let uri: http::Uri = format!("/{path}")
        .parse()
        .expect("valid request rebuild");

    let mut req = http::Request::new(AxumBody::empty());
    *req.method_mut() = parts.method.clone();
    *req.uri_mut() = uri;
    *req.version_mut() = parts.version;
    *req.headers_mut() = parts.headers.clone();
    *req.extensions_mut() = parts.extensions.clone();
    req
}
/// Runs `req` through `serve_dir` once and converts the result into an axum
/// response.
///
/// `Cache-Control` is set to `cache_control` on success responses and on
/// `304 Not Modified`, since a 304 should carry the caching headers the
/// matching 200 would have had. If `cache_control` is not a valid header
/// value, `public, max-age=0` is used instead. Other responses are passed
/// through with their headers untouched.
///
/// A service error is mapped to a plain-text `404 Not Found`.
async fn serve_via_serve_dir(
    serve_dir: ServeDir,
    req: Request,
    cache_control: &str,
) -> axum::response::Response {
    match serve_dir.oneshot(req).await {
        Ok(res) => {
            let (mut parts, body) = res.into_parts();
            let cacheable =
                parts.status.is_success() || parts.status == StatusCode::NOT_MODIFIED;
            if cacheable {
                let value = http::HeaderValue::from_str(cache_control)
                    .unwrap_or_else(|_| http::HeaderValue::from_static("public, max-age=0"));
                parts.headers.insert(http::header::CACHE_CONTROL, value);
            }
            Response::from_parts(parts, AxumBody::new(body))
        }
        Err(_) => (StatusCode::NOT_FOUND, "Not Found").into_response(),
    }
}
/// Decides whether a request may be answered by the SPA index fallback.
///
/// All three conditions must hold:
/// * the method is `GET` or `HEAD`;
/// * the `Accept` header is present, readable as a string, and contains
///   `text/html` or `*/*`;
/// * the URI path has no file extension.
fn is_spa_qualified(req: &Request) -> bool {
    let verb = req.method();
    let is_read_only = verb == http::Method::GET || verb == http::Method::HEAD;

    // A missing or non-string Accept header counts as "does not want HTML".
    let wants_html = req
        .headers()
        .get(http::header::ACCEPT)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|accept| accept.contains("text/html") || accept.contains("*/*"));

    // Paths with an extension are not given the SPA fallback.
    let names_a_file = std::path::Path::new(req.uri().path())
        .extension()
        .is_some();

    is_read_only && wants_html && !names_a_file
}
/// Serves the root (`/`) of `serve_dir` in place of the requested path.
///
/// The rewritten request keeps the method and headers of `req`, targets `/`,
/// and has an empty body. The result goes through `serve_via_serve_dir`, so
/// `cache_control` is applied the same way as for ordinary files.
async fn serve_spa_index(
    serve_dir: ServeDir,
    req: Request,
    cache_control: &str,
) -> axum::response::Response {
    let base = axum::http::Request::builder()
        .method(req.method().clone())
        .uri(axum::http::Uri::from_static("/"));

    // Copy every header value of the incoming request onto the builder.
    let index_req = req
        .headers()
        .iter()
        .fold(base, |builder, (name, value)| builder.header(name, value))
        .body(AxumBody::empty())
        .expect("valid request");

    serve_via_serve_dir(serve_dir, index_req, cache_control).await
}
/// Serves the custom error page configured for `status`, if there is one.
///
/// The page is looked up in `error_pages` and served with a `GET` for its
/// file name against a `ServeDir` rooted at its parent directory, using
/// `no-cache` as the cache policy. If that yields a success response and
/// `status` is a valid HTTP status code, the response status is overwritten
/// with `status`. A non-success response from the file lookup is returned
/// unchanged.
///
/// Returns `None` when:
/// * no page is configured for `status`;
/// * the configured path has no parent or no UTF-8 file name;
/// * the file name cannot be used as a request URI (for example, it contains
///   a space). This misconfiguration previously caused a panic.
async fn try_serve_error_page(
    error_pages: &HashMap<u16, PathBuf>,
    status: u16,
) -> Option<axum::response::Response> {
    let error_path = error_pages.get(&status)?;
    let parent = error_path.parent()?;
    let file_name = error_path.file_name()?.to_str()?;

    // An unparsable URI surfaces as an error from `body`; treat it as
    // "no usable error page" instead of crashing the request.
    let error_req = axum::http::Request::builder()
        .method(http::Method::GET)
        .uri(format!("/{file_name}"))
        .body(AxumBody::empty())
        .ok()?;

    let mut resp = serve_via_serve_dir(ServeDir::new(parent), error_req, "no-cache").await;
    if resp.status().is_success()
        && let Ok(code) = StatusCode::from_u16(status)
    {
        // The page was found: report the original error status, not 200.
        *resp.status_mut() = code;
    }
    Some(resp)
}
#[cfg(test)]
mod tests {
    use super::prefix_matches_segment;

    /// Checks that every `(path, prefix)` pair in `cases` yields `expected`.
    fn check_all(cases: &[(&str, &str)], expected: bool) {
        for &(path, prefix) in cases {
            assert_eq!(
                prefix_matches_segment(path, prefix),
                expected,
                "path={path:?} prefix={prefix:?}"
            );
        }
    }

    /// The root mount matches any path.
    #[test]
    fn root_prefix_matches_everything() {
        check_all(&[("/", "/"), ("/foo", "/"), ("/foo/bar", "/")], true);
    }

    /// A prefix matches itself and anything nested under it.
    #[test]
    fn exact_prefix_match() {
        check_all(
            &[("/asset", "/asset"), ("/asset/file.txt", "/asset")],
            true,
        );
    }

    /// A prefix must end on a segment boundary, not in the middle of a segment.
    #[test]
    fn segment_boundary_rejects_partial_match() {
        check_all(
            &[
                ("/assets", "/asset"),
                ("/assets/file.txt", "/asset"),
                ("/assetx", "/asset"),
            ],
            false,
        );
    }

    /// Unrelated paths never match.
    #[test]
    fn non_matching_prefix() {
        check_all(
            &[("/other", "/asset"), ("/other/file.txt", "/asset")],
            false,
        );
    }

    /// Multi-segment prefixes follow the same boundary rule.
    #[test]
    fn nested_prefix() {
        check_all(&[("/assets/sub/file.txt", "/assets/sub")], true);
        check_all(&[("/assets/other/file.txt", "/assets/sub")], false);
    }
}