use axum::{
Router,
body::{Body, Bytes},
http::{Method, StatusCode, header},
response::{IntoResponse, Response},
routing::{any, get},
};
#[cfg(not(debug_assertions))]
use include_dir::File;
#[allow(unused_imports)]
use log::{debug, info, trace, warn};
use std::path::PathBuf;
use std::process::{Child, Command};
use std::sync::Arc;
pub mod frameworks;
use frameworks::Framework;
pub use include_dir;
pub use include_dir::Dir;
/// Embeds a frontend build directory in release builds only.
///
/// Expands to an `Option<&'static Dir<'static>>`:
/// * release (`not(debug_assertions)`): `Some(&DIR)`, where `DIR` is produced by
///   `include_dir::include_dir!($path)` at compile time;
/// * debug: `None`, so nothing is embedded and the dev-server proxy is used instead.
///
/// `$path` is passed verbatim to `include_dir!`; the tests in this file use
/// `"$CARGO_MANIFEST_DIR/..."` style paths (presumably expanded by `include_dir!` — confirm).
#[macro_export]
macro_rules! embedded_dir {
($path:tt) => {{
#[cfg(not(debug_assertions))]
{
// A `static` gives the embedded tree the `'static` lifetime the config expects.
static DIR: $crate::Dir<'static> = $crate::include_dir::include_dir!($path);
Some(&DIR)
}
#[cfg(debug_assertions)]
{
// Debug builds never embed; the path is not even evaluated.
None::<&$crate::Dir<'static>>
}
}};
}
/// Configuration shared by the Vite routers and helpers in this module.
///
/// In debug builds requests are proxied to a running Vite dev server; in
/// release builds assets are served from the embedded `dir`.
#[derive(Clone, Debug)]
pub struct ViteConfig {
/// Port of the Vite dev server that debug builds proxy to.
pub dev_port: u16,
/// Host name of the Vite dev server that debug builds proxy to.
pub dev_host: String,
/// Embedded frontend build output (e.g. from `embedded_dir!`); only read in release builds.
pub dir: Option<&'static Dir<'static>>,
/// Path inside `dir` of the page the release asset handler serves with status 404.
pub not_found: String,
/// URL prefix under which static assets are mounted (e.g. `/static/`).
pub prefix: String,
/// Working directory for `spawn_dev_server`; spawning fails when this is `None`.
pub frontend_root: Option<PathBuf>,
/// Shell command `spawn_dev_server` runs to start Vite.
pub dev_command: String,
/// Whether the dev server should be started automatically.
/// NOTE(review): not read in this file — presumably consulted by callers; confirm.
pub auto_start: bool,
/// Frontend framework; selects the HMR preamble returned by `hmr_scripts`.
pub framework: Framework,
/// Entry script path (relative to `prefix`) returned by `entry_assets` in debug builds.
pub dev_script: String,
/// Key of the entry chunk looked up in `.vite/manifest.json` by `entry_assets`.
pub manifest_key: String,
/// Path prefixes (relative, no leading `/`) that release builds never serve
/// from `dir`; matching requests get a not-found response.
pub excluded_prefixes: Vec<String>,
/// HTTP client used to proxy requests to the dev server (debug builds only).
#[cfg(debug_assertions)]
pub client: reqwest::Client,
}
impl Default for ViteConfig {
    /// Conventional Vite defaults: dev server on `localhost:5173`, assets under
    /// `/static/`, `npm run dev` as the start command and no embedded build.
    fn default() -> Self {
        Self {
            // Where the Vite dev server listens.
            dev_host: String::from("localhost"),
            dev_port: 5173,
            // How assets are mounted and served.
            prefix: String::from("/static/"),
            not_found: String::from("404.html"),
            dir: None,
            excluded_prefixes: vec![],
            // Dev-server process management.
            frontend_root: None,
            dev_command: String::from("npm run dev"),
            auto_start: false,
            // Entry point and HMR settings.
            framework: Default::default(),
            dev_script: String::from("src/main.tsx"),
            manifest_key: String::from("index.html"),
            #[cfg(debug_assertions)]
            client: reqwest::Client::new(),
        }
    }
}
impl ViteConfig {
pub fn from_env(dir: Option<&'static Dir<'static>>) -> Self {
let port = std::env::var("VITE_PORT")
.unwrap_or_else(|_| "5173".to_string())
.parse()
.unwrap_or(5173);
let prefix = std::env::var("VITE_STATIC_PREFIX").unwrap_or_else(|_| "/static/".to_string());
let frontend_root = std::env::var("VITE_ROOT").ok().map(PathBuf::from);
let dev_command =
std::env::var("VITE_DEV_CMD").unwrap_or_else(|_| "npm run dev".to_string());
let auto_start = std::env::var("VITE_AUTO_START")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
let framework = std::env::var("VITE_FRAMEWORK")
.map(|v| match v.to_lowercase().as_str() {
"react" => Framework::React,
"vue" => Framework::Vue,
"svelte" => Framework::Svelte,
_ => Framework::None,
})
.unwrap_or_default();
let dev_host = std::env::var("VITE_DEV_HOST").unwrap_or_else(|_| "localhost".to_string());
Self {
dev_port: port,
dev_host,
dir,
not_found: "404.html".to_string(),
prefix,
frontend_root,
dev_command,
auto_start,
framework,
dev_script: std::env::var("VITE_DEV_SCRIPT")
.unwrap_or_else(|_| "src/main.tsx".to_string()),
manifest_key: std::env::var("VITE_MANIFEST_KEY")
.unwrap_or_else(|_| "index.html".to_string()),
excluded_prefixes: Vec::new(),
#[cfg(debug_assertions)]
client: reqwest::Client::new(),
}
}
pub fn hmr_scripts(&self) -> String {
if !cfg!(debug_assertions) {
return String::new();
}
let prefix = self.prefix.trim_end_matches('/');
let mut scripts = String::new();
if let Some(preamble) = self.framework.preamble(prefix) {
scripts.push_str(&preamble);
scripts.push('\n');
}
scripts.push_str(&format!(
r#"<script type="module" src="{}/{}"></script>"#,
prefix, "@vite/client"
));
scripts
}
}
/// Script and stylesheet URLs for a frontend entry point, e.g. for HTML templates.
#[derive(Clone, Default, Debug)]
pub struct EntryAssets {
/// URL of the entry's JavaScript module, rooted at the configured prefix.
/// Empty when the release manifest lookup fails.
pub script: String,
/// URLs of stylesheets to link. Always empty in debug builds
/// (presumably because Vite injects CSS itself in dev mode — confirm).
pub stylesheets: Vec<String>,
}
impl ViteConfig {
    /// Resolves the assets of the configured `manifest_key` entry.
    ///
    /// See [`ViteConfig::entry_assets_for`].
    pub fn entry_assets(&self) -> EntryAssets {
        self.entry_assets_for(&self.manifest_key)
    }

    /// Resolves the script and stylesheet URLs for `manifest_key`.
    ///
    /// * Release builds read `.vite/manifest.json` from the embedded `dir`. If
    ///   the manifest or the embedded dir is missing, a warning is logged and
    ///   the dev-mode fallback below is returned.
    /// * Debug builds (and the fallback) return `{prefix}/{dev_script}` with no
    ///   stylesheets. A leading `/` on `dev_script` is ignored so no `//` appears.
    // `manifest_key` is only consulted in release builds.
    #[cfg_attr(debug_assertions, allow(unused_variables))]
    pub fn entry_assets_for(&self, manifest_key: &str) -> EntryAssets {
        let base = self.prefix.trim_end_matches('/');
        #[cfg(not(debug_assertions))]
        if let Some(dir) = self.dir {
            if let Some(json) = dir
                .get_file(".vite/manifest.json")
                .and_then(|file| file.contents_utf8())
            {
                return EntryAssets::from_manifest(json, base, manifest_key);
            }
            warn!(
                "[axum-vite] entry_assets: dist/.vite/manifest.json not found in embedded dir. \
                 Add `build: {{ manifest: true }}` to vite.config and rebuild the frontend. \
                 Falling back to dev-mode paths — assets will 404 in production."
            );
        } else {
            warn!(
                "[axum-vite] entry_assets: no embedded dir configured. \
                 Falling back to dev-mode paths — assets will 404 in production."
            );
        }
        let script = self.dev_script.trim_start_matches('/');
        EntryAssets {
            script: format!("{base}/{script}"),
            stylesheets: vec![],
        }
    }
}
impl EntryAssets {
    /// Builds [`EntryAssets`] for `key` from the JSON text of a Vite `.vite/manifest.json`.
    ///
    /// `base` (given without a trailing slash) is prepended, followed by `/`, to
    /// every file path. `stylesheets` lists the entry's own `css` files, then
    /// the CSS of every chunk reachable through `imports`. Imported chunks are
    /// visited depth-first, so dependencies come before dependents, and each
    /// chunk is visited once. This follows Vite's backend-integration guide;
    /// without it, CSS that Vite split into a shared chunk is never linked.
    ///
    /// Returns [`EntryAssets::default`] when the JSON is invalid, is not an
    /// object, or has no `key` entry.
    #[cfg(not(debug_assertions))]
    fn from_manifest(json: &str, base: &str, key: &str) -> Self {
        let Ok(manifest) = serde_json::from_str::<serde_json::Value>(json) else {
            warn!("[axum-vite] entry_assets: failed to parse manifest.json as JSON");
            return Self::default();
        };
        let Some(entries) = manifest.as_object() else {
            return Self::default();
        };
        let Some(entry) = entries.get(key) else {
            warn!(
                "[axum-vite] entry_assets: key {:?} not found in manifest.json. \
                 Available keys: {}",
                key,
                entries.keys().cloned().collect::<Vec<_>>().join(", ")
            );
            return Self::default();
        };
        let script = entry
            .get("file")
            .and_then(serde_json::Value::as_str)
            .map(|file| format!("{base}/{file}"))
            .unwrap_or_default();

        let mut stylesheets = Vec::new();
        Self::push_chunk_css(entry, base, &mut stylesheets);
        // The entry itself counts as seen so an import cycle cannot re-walk it.
        let mut seen = std::collections::HashSet::new();
        seen.insert(key.to_owned());
        Self::push_imported_css(entries, entry, base, &mut seen, &mut stylesheets);

        Self {
            script,
            stylesheets,
        }
    }

    /// Appends the URLs of `chunk`'s own `css` files to `out`, skipping URLs already present.
    #[cfg(not(debug_assertions))]
    fn push_chunk_css(chunk: &serde_json::Value, base: &str, out: &mut Vec<String>) {
        let css_files = chunk
            .get("css")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(serde_json::Value::as_str);
        for css in css_files {
            let url = format!("{base}/{css}");
            if !out.contains(&url) {
                out.push(url);
            }
        }
    }

    /// Walks `chunk`'s `imports` depth-first. For each imported chunk, its own
    /// imports' CSS is appended before its CSS. `seen` prevents cycles and duplicates.
    #[cfg(not(debug_assertions))]
    fn push_imported_css(
        entries: &serde_json::Map<String, serde_json::Value>,
        chunk: &serde_json::Value,
        base: &str,
        seen: &mut std::collections::HashSet<String>,
        out: &mut Vec<String>,
    ) {
        let imports = chunk
            .get("imports")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(serde_json::Value::as_str);
        for import_key in imports {
            if !seen.insert(import_key.to_owned()) {
                continue;
            }
            if let Some(imported) = entries.get(import_key) {
                Self::push_imported_css(entries, imported, base, seen, out);
                Self::push_chunk_css(imported, base, out);
            }
        }
    }
}
/// Handle to the dev-server child process started by `spawn_dev_server`.
///
/// Dropping the handle kills the child process and waits for it to exit.
/// NOTE(review): only the direct child is signalled; processes it spawned
/// itself may outlive it — confirm for the configured `dev_command`.
pub struct DevServerHandle {
child: Child,
}
impl std::fmt::Debug for DevServerHandle {
    /// Renders as `DevServerHandle { .. }`; the child process is not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("DevServerHandle");
        out.finish_non_exhaustive()
    }
}
impl Drop for DevServerHandle {
    /// Kills the dev server and reaps it so no zombie process is left behind.
    ///
    /// Both steps are best effort: `Drop` cannot report failures. For example,
    /// killing a process that has already exited returns an error, which is ignored.
    fn drop(&mut self) {
        let _kill_outcome = self.child.kill();
        let _wait_outcome = self.child.wait();
    }
}
/// Starts the Vite dev server by running `config.dev_command` in `config.frontend_root`.
///
/// # Errors
/// * `InvalidInput` if `frontend_root` is `None`.
/// * `Unsupported` on targets that are neither unix nor windows.
/// * Any I/O error from spawning the shell.
pub fn spawn_dev_server(config: &ViteConfig) -> std::io::Result<DevServerHandle> {
    let root = config.frontend_root.as_ref().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "frontend_root must be set to spawn dev server",
        )
    })?;
    info!(
        "[axum-vite] spawning dev server in {:?} (`{}`)",
        root, config.dev_command
    );
    let mut command = dev_shell_command(&config.dev_command)?;
    command
        .current_dir(root)
        .spawn()
        .map(|child| DevServerHandle { child })
}

/// Wraps `command_line` in `sh -c`. `exec` replaces the shell with the dev
/// command, so killing the child signals the command rather than the wrapper shell.
#[cfg(unix)]
fn dev_shell_command(command_line: &str) -> std::io::Result<Command> {
    let mut command = Command::new("sh");
    command.arg("-c").arg(format!("exec {command_line}"));
    Ok(command)
}

/// Wraps `command_line` in `cmd /C`.
#[cfg(windows)]
fn dev_shell_command(command_line: &str) -> std::io::Result<Command> {
    let mut command = Command::new("cmd");
    command.arg("/C").arg(command_line);
    Ok(command)
}

/// Fallback for targets with no known shell; previously such targets failed to compile.
#[cfg(not(any(unix, windows)))]
fn dev_shell_command(_command_line: &str) -> std::io::Result<Command> {
    Err(std::io::Error::new(
        std::io::ErrorKind::Unsupported,
        "spawning the dev server is only supported on unix and windows",
    ))
}
/// Router that serves static assets for every path it receives.
///
/// Debug builds proxy each request (method, headers, body and query string) to
/// the Vite dev server. Release builds serve files from the embedded `dir`.
/// It is typically nested under `config.prefix`.
pub fn router<S>(config: ViteConfig) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let shared = Arc::new(config);
    let asset_handler = move |method: Method,
                              axum::extract::Path(path): axum::extract::Path<String>,
                              uri: axum::http::Uri,
                              headers: axum::http::HeaderMap,
                              body: Bytes| {
        let cfg = Arc::clone(&shared);
        // Re-attach the query string, which the path extractor drops.
        let full_path = if let Some(query) = uri.query() {
            format!("{path}?{query}")
        } else {
            path
        };
        async move { serve_asset(Some(full_path), None, headers, method, body, cfg).await }
    };
    Router::new().route("/{*path}", any(asset_handler))
}
/// Router for a single-page application.
///
/// * `/` serves `index.html`.
/// * `/{*path}` serves the requested file, or falls back to `index.html` for
///   GET requests whose `Accept` header contains `text/html`.
/// * The asset `router` is nested under `config.prefix`, unless the prefix is `/`.
/// * Everything is wrapped in `hmr_injection_middleware`.
pub fn spa_router<S>(config: ViteConfig) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let asset_mount = format!("/{}", config.prefix.trim_matches('/'));
    // A prefix of "/" would collide with the SPA routes, so no static router is mounted then.
    let asset_router = (asset_mount != "/").then(|| router(config.clone()));

    let shared = Arc::new(config);
    let index_cfg = Arc::clone(&shared);
    let catchall_cfg = Arc::clone(&shared);

    let index_route = get(move || {
        let cfg = Arc::clone(&index_cfg);
        async move { _serve_index(cfg).await }
    });
    let catchall_route = any(
        move |method: Method,
              axum::extract::Path(path): axum::extract::Path<String>,
              uri: axum::http::Uri,
              headers: axum::http::HeaderMap,
              body: Bytes| {
            let cfg = Arc::clone(&catchall_cfg);
            let full_path = if let Some(query) = uri.query() {
                format!("{path}?{query}")
            } else {
                path
            };
            async move { _serve_spa_catchall(full_path, headers, method, body, cfg).await }
        },
    );

    let mut app = Router::new()
        .route("/", index_route)
        .route("/{*path}", catchall_route);
    if let Some(static_router) = asset_router {
        app = app.nest(&asset_mount, static_router);
    }
    app.layer(axum::middleware::from_fn_with_state(
        shared,
        hmr_injection_middleware,
    ))
}
/// Serves the application's `index.html`: proxied from Vite in debug builds,
/// read from the embedded `dir` in release builds.
pub async fn serve_index(config: ViteConfig) -> impl IntoResponse {
    let shared = Arc::new(config);
    _serve_index(shared).await
}
/// Debug build: fetches the index page from the Vite dev server at `{prefix}/`.
#[cfg(debug_assertions)]
async fn _serve_index(config: Arc<ViteConfig>) -> Response {
    // The request is synthesised here, so there are no client headers or body to forward.
    let no_headers = axum::http::HeaderMap::default();
    let empty_body = Bytes::new();
    proxy_to_vite("", &config, &no_headers, &Method::GET, empty_body).await
}
/// Debug build: proxies `path` to the dev server without the asset prefix.
///
/// Only a 404 for a GET request whose `Accept` header contains `text/html`
/// (a page navigation) is replaced by the index page. Any other response is
/// returned unchanged.
#[cfg(debug_assertions)]
async fn _serve_spa_catchall(
    path: String,
    headers: axum::http::HeaderMap,
    method: Method,
    body: Bytes,
    config: Arc<ViteConfig>,
) -> Response {
    let wants_html_page = method == Method::GET
        && headers
            .get(axum::http::header::ACCEPT)
            .and_then(|value| value.to_str().ok())
            .is_some_and(|accept| accept.contains("text/html"));
    let upstream = proxy_raw(&path, &config, &headers, &method, body).await;
    if wants_html_page && upstream.status() == StatusCode::NOT_FOUND {
        return _serve_index(config).await;
    }
    upstream
}
/// Release build: serves `path` from the embedded `dir`.
///
/// * Paths under an excluded prefix get a 404.
/// * Existing files are served directly.
/// * Otherwise, GET requests whose `Accept` header contains `text/html` get the
///   index page, and everything else gets a 404.
#[cfg(not(debug_assertions))]
async fn _serve_spa_catchall(
    path: String,
    headers: axum::http::HeaderMap,
    method: Method,
    _body: Bytes,
    config: Arc<ViteConfig>,
) -> Response {
    let trimmed = path.trim_start_matches('/');
    // The query string is not part of the embedded file name.
    let file_key = match trimmed.split_once('?') {
        Some((file_part, _query)) => file_part,
        None => trimmed,
    };
    let is_excluded = config
        .excluded_prefixes
        .iter()
        .any(|excluded| file_key.starts_with(excluded.as_str()));
    if is_excluded {
        return StatusCode::NOT_FOUND.into_response();
    }
    if let Some(file) = config.dir.and_then(|dir| dir.get_file(file_key)) {
        return serve_embedded_file(file, None);
    }
    let accepts_html = headers
        .get(axum::http::header::ACCEPT)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|accept| accept.contains("text/html"));
    if method == Method::GET && accepts_html {
        _serve_index(config).await
    } else {
        StatusCode::NOT_FOUND.into_response()
    }
}
/// Release build: serves `index.html` from the embedded `dir` as uncacheable HTML.
/// Returns 404 (and logs a warning) when the dir or the file is missing.
#[cfg(not(debug_assertions))]
async fn _serve_index(config: Arc<ViteConfig>) -> Response {
    let Some(dir) = config.dir else {
        warn!(
            "[axum-vite] serve_index: no embedded dir configured — pass the include_dir! output to ViteConfig::from_env"
        );
        return StatusCode::NOT_FOUND.into_response();
    };
    let Some(index) = dir.get_file("index.html") else {
        warn!("[axum-vite] serve_index: index.html not found in embedded dir");
        return StatusCode::NOT_FOUND.into_response();
    };
    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
        .header(header::CACHE_CONTROL, "no-store")
        .body(Body::from(index.contents()))
        .unwrap()
}
/// Forwards one request to `url` on the Vite dev server and converts the reply
/// into an axum [`Response`].
///
/// * Request headers are copied except `Host` and `Accept-Encoding`. Dropping
///   `Accept-Encoding` presumably keeps bodies uncompressed so the HMR
///   middleware can rewrite HTML.
/// * Response headers are copied except `Transfer-Encoding` and
///   `Content-Length`; the body is buffered and re-framed by axum.
/// * If the dev server cannot be reached, returns 503.
/// * If the response body cannot be read, returns 502 instead of an empty body
///   with the upstream status.
#[cfg(debug_assertions)]
async fn do_proxy(
    url: &str,
    log_path: &str,
    config: &ViteConfig,
    headers: &axum::http::HeaderMap,
    method: &Method,
    body: Bytes,
) -> Response {
    trace!("[axum-vite] → /{}", log_path);
    let mut request_builder = config.client.request(method.clone(), url);
    for (name, value) in headers.iter() {
        if name != header::HOST && name != header::ACCEPT_ENCODING {
            request_builder = request_builder.header(name, value);
        }
    }
    let upstream = match request_builder.body(body).send().await {
        Ok(resp) => resp,
        Err(err) => {
            warn!("[axum-vite] dev server unreachable at {}: {}", url, err);
            return Response::builder()
                .status(StatusCode::SERVICE_UNAVAILABLE)
                .body(Body::from("Vite dev server unreachable"))
                .unwrap();
        }
    };
    let mut builder = Response::builder().status(upstream.status());
    for (name, value) in upstream.headers().iter() {
        // Framing headers are dropped because the body is re-buffered below.
        if name != header::TRANSFER_ENCODING && name != header::CONTENT_LENGTH {
            builder = builder.header(name, value);
        }
    }
    match upstream.bytes().await {
        Ok(bytes) => builder.body(Body::from(bytes)).unwrap(),
        Err(err) => {
            warn!(
                "[axum-vite] failed to read dev server response from {}: {}",
                url, err
            );
            Response::builder()
                .status(StatusCode::BAD_GATEWAY)
                .body(Body::from("Vite dev server response could not be read"))
                .unwrap()
        }
    }
}
/// Proxies `path_str` to the dev server under the configured asset prefix,
/// i.e. `http://{dev_host}:{dev_port}/{prefix}/{path}`. The prefix segment is
/// omitted when the prefix is `/` or empty.
#[cfg(debug_assertions)]
async fn proxy_to_vite(
    path_str: &str,
    config: &ViteConfig,
    headers: &axum::http::HeaderMap,
    method: &Method,
    body: Bytes,
) -> Response {
    let origin = format!("http://{}:{}", config.dev_host, config.dev_port);
    let relative = path_str.trim_start_matches('/');
    let url = match config.prefix.trim_matches('/') {
        "" => format!("{origin}/{relative}"),
        mount => format!("{origin}/{mount}/{relative}"),
    };
    do_proxy(&url, path_str, config, headers, method, body).await
}
/// Proxies `path_str` to the dev server root, without the asset prefix.
#[cfg(debug_assertions)]
async fn proxy_raw(
    path_str: &str,
    config: &ViteConfig,
    headers: &axum::http::HeaderMap,
    method: &Method,
    body: Bytes,
) -> Response {
    let origin = format!("http://{}:{}", config.dev_host, config.dev_port);
    let url = format!("{origin}/{}", path_str.trim_start_matches('/'));
    do_proxy(&url, path_str, config, headers, method, body).await
}
/// Debug build: proxies an asset request to the Vite dev server.
///
/// `None` for `path` yields a plain-text 404. `_mime_type` is ignored because
/// the dev server sets its own content types.
#[cfg(debug_assertions)]
pub async fn serve_asset(
    path: Option<String>,
    _mime_type: Option<&str>,
    headers: axum::http::HeaderMap,
    method: Method,
    body: Bytes,
    config: Arc<ViteConfig>,
) -> impl IntoResponse {
    let Some(path_str) = path else {
        return (StatusCode::NOT_FOUND, "Not Found").into_response();
    };
    proxy_to_vite(&path_str, &config, &headers, &method, body)
        .await
        .into_response()
}
#[cfg(debug_assertions)]
pub async fn hmr_injection_middleware(
axum::extract::State(config): axum::extract::State<Arc<ViteConfig>>,
request: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
let response = next.run(request).await;
if let Some(content_type) = response.headers().get(header::CONTENT_TYPE)
&& content_type.to_str().unwrap_or("").contains("text/html")
{
let hmr_scripts = config.hmr_scripts();
if hmr_scripts.is_empty() {
return response;
}
let (parts, body) = response.into_parts();
let bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(b) => b,
Err(_) => return Response::from_parts(parts, Body::empty()),
};
if let Ok(mut html) = String::from_utf8(bytes.to_vec()) {
if html.contains("@vite/client") {
return Response::from_parts(parts, Body::from(html));
}
if let Some(pos) = html.find("</head>") {
html.insert_str(pos, &format!("\n{}\n", hmr_scripts));
} else if let Some(pos) = html.find("</body>") {
html.insert_str(pos, &format!("\n{}\n", hmr_scripts));
} else {
html.push_str(&format!("\n{}\n", hmr_scripts));
}
let mut res = Response::from_parts(parts, Body::from(html.clone()));
res.headers_mut().insert(
header::CONTENT_LENGTH,
axum::http::HeaderValue::from(html.len()),
);
return res;
} else {
return Response::from_parts(parts, Body::from(bytes));
}
}
response
}
/// Release build: serves one embedded file with a content type and cache policy.
///
/// The content type is `mime_type` if given, else guessed from the file path.
/// Cache policy, first match wins:
/// * HTML and service-worker scripts → `no-store`;
/// * web manifests → one day;
/// * anything under an `assets` directory → one year, immutable;
/// * everything else → `public, no-cache`.
#[cfg(not(debug_assertions))]
fn serve_embedded_file(file: &'static File<'static>, mime_type: Option<&str>) -> Response {
    let path = PathBuf::from(file.path());
    let content_type = mime_type.map_or_else(
        || mime_guess::from_path(&path).first_or_octet_stream().to_string(),
        str::to_string,
    );

    let file_name = path.file_name().and_then(|name| name.to_str());
    let extension = path.extension().and_then(|ext| ext.to_str());
    let is_service_worker = matches!(
        file_name,
        Some("sw.js" | "service-worker.js" | "service-worker.ts")
    );
    let is_web_manifest = content_type.contains("manifest") || extension == Some("webmanifest");
    let in_assets_dir = path
        .components()
        .any(|component| component.as_os_str() == "assets");

    let cache_control = if content_type.contains("text/html") || is_service_worker {
        "no-store"
    } else if is_web_manifest {
        "public, max-age=86400"
    } else if in_assets_dir {
        "public, max-age=31536000, immutable"
    } else {
        "public, no-cache"
    };

    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, content_type)
        .header(header::CACHE_CONTROL, cache_control)
        .body(Body::from(file.contents()))
        .unwrap()
}
/// Release build: a pass-through with the same signature as the debug
/// middleware, so `spa_router` wires it up identically. Nothing is injected.
#[cfg(not(debug_assertions))]
pub async fn hmr_injection_middleware(
    axum::extract::State(_unused_config): axum::extract::State<Arc<ViteConfig>>,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let passthrough = next.run(request);
    passthrough.await
}
/// Release build: serves an asset from the embedded `dir`.
///
/// The query string is ignored for lookup. Paths under `excluded_prefixes`,
/// missing files and a `None` path all produce the not-found response: the
/// embedded `not_found` page with status 404 if it exists, otherwise a bare 404.
#[cfg(not(debug_assertions))]
pub async fn serve_asset(
    path: Option<String>,
    mime_type: Option<&str>,
    _headers: axum::http::HeaderMap,
    _method: Method,
    _body: Bytes,
    config: Arc<ViteConfig>,
) -> impl IntoResponse {
    // Builds the 404 response, preferring the configured embedded not-found page.
    fn not_found_response(config: &ViteConfig) -> Response {
        match config.dir.and_then(|dir| dir.get_file(&config.not_found)) {
            Some(page) => {
                let mut res = serve_embedded_file(page, Some("text/html; charset=utf-8"));
                *res.status_mut() = StatusCode::NOT_FOUND;
                res
            }
            None => StatusCode::NOT_FOUND.into_response(),
        }
    }

    let Some(path_str) = path else {
        return not_found_response(&config);
    };
    let clean = path_str.trim_start_matches('/');
    let file_key = clean.split_once('?').map_or(clean, |(file_part, _)| file_part);
    let is_excluded = config
        .excluded_prefixes
        .iter()
        .any(|excluded| file_key.starts_with(excluded.as_str()));
    if is_excluded {
        return not_found_response(&config);
    }
    if let Some(file) = config.dir.and_then(|dir| dir.get_file(file_key)) {
        debug!("[axum-vite] serving /{}", clean);
        return serve_embedded_file(file, mime_type);
    }
    warn!("[axum-vite] 404 /{}", clean);
    not_found_response(&config)
}
#[cfg(test)]
// `vec!` literals are kept in assertions for readability.
#[allow(clippy::useless_vec)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        routing::get,
    };
    use tower::ServiceExt;

    #[test]
    fn default_config_values() {
        let config = ViteConfig::default();
        assert_eq!(config.dev_port, 5173);
        assert_eq!(config.dev_host, "localhost");
        assert_eq!(config.prefix, "/static/");
        assert_eq!(config.not_found, "404.html");
        assert!(!config.auto_start);
        assert!(config.excluded_prefixes.is_empty());
        assert!(config.dir.is_none());
    }

    /// Debug builds emit `@vite/client` (no React preamble for `Framework::None`);
    /// release builds emit nothing.
    #[test]
    fn hmr_scripts_follow_build_mode() {
        let config = ViteConfig {
            framework: frameworks::Framework::None,
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let scripts = config.hmr_scripts();
        if cfg!(debug_assertions) {
            assert!(scripts.contains("@vite/client"), "missing @vite/client");
            assert!(
                !scripts.contains("@react-refresh"),
                "unexpected react preamble"
            );
        } else {
            assert!(scripts.is_empty(), "expected empty in release");
        }
    }

    #[test]
    fn hmr_scripts_react_preamble_in_debug() {
        if !cfg!(debug_assertions) {
            return;
        }
        let config = ViteConfig {
            framework: frameworks::Framework::React,
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let scripts = config.hmr_scripts();
        assert!(scripts.contains("@react-refresh"), "missing react preamble");
        assert!(
            scripts.contains("injectIntoGlobalHook"),
            "missing injectIntoGlobalHook"
        );
        let preamble_pos = scripts.find("@react-refresh").unwrap();
        let client_pos = scripts.find("@vite/client").unwrap();
        assert!(
            preamble_pos < client_pos,
            "preamble must precede @vite/client"
        );
    }

    #[test]
    fn hmr_scripts_prefix_interpolated() {
        if !cfg!(debug_assertions) {
            return;
        }
        let config = ViteConfig {
            framework: frameworks::Framework::React,
            prefix: "/assets/".to_string(),
            ..Default::default()
        };
        let scripts = config.hmr_scripts();
        assert!(
            scripts.contains("/assets/@react-refresh"),
            "prefix not interpolated in preamble"
        );
        assert!(
            scripts.contains("/assets/@vite/client"),
            "prefix not interpolated in @vite/client"
        );
    }

    #[test]
    fn excluded_prefixes_default_empty() {
        assert!(ViteConfig::default().excluded_prefixes.is_empty());
    }

    #[test]
    fn excluded_prefixes_match_correctly() {
        let excluded = vec!["templates/".to_string(), "index.html".to_string()];
        let is_excluded = |path: &str| excluded.iter().any(|p| path.starts_with(p.as_str()));
        assert!(is_excluded("templates/base.html"));
        assert!(is_excluded("index.html"));
        assert!(!is_excluded("assets/main.js"));
        assert!(!is_excluded("favicon.ico"));
    }

    #[test]
    fn spawn_dev_server_errors_without_root() {
        let config = ViteConfig::default();
        let err = spawn_dev_server(&config).expect_err("expected error when root is None");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn router_returns_unavailable_when_vite_not_running() {
        let config = ViteConfig {
            dev_port: 1,
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let app: Router = Router::new().nest("/static", router(config));
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/static/main.js")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn router_no_prefix_returns_unavailable() {
        let config = ViteConfig {
            dev_port: 1,
            prefix: "/".to_string(),
            ..Default::default()
        };
        let app: Router = router(config);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/main.js")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn spa_router_unknown_path_passes_through_vite_response() {
        let config = ViteConfig {
            dev_port: 1,
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let app: Router = spa_router(config);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/some/spa/page")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn spa_router_post_to_unknown_path_is_not_swallowed() {
        let config = ViteConfig {
            dev_port: 1,
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let app: Router = spa_router(config);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/missing")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_ne!(
            response.status(),
            StatusCode::OK,
            "POST to unknown path must not return 200 (index.html swallow)"
        );
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn serve_index_unavailable_when_vite_not_running() {
        let config = ViteConfig {
            dev_port: 1,
            ..Default::default()
        };
        let response = serve_index(config).await.into_response();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn serve_index_route_registered() {
        let config = ViteConfig {
            dev_port: 1,
            ..Default::default()
        };
        let app: Router = Router::new().route(
            "/",
            get({
                let c = config.clone();
                move || serve_index(c.clone())
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// Test app with an HTML route (`/page`) and a JS route (`/app.js`),
    /// wrapped in the HMR injection middleware.
    #[cfg(debug_assertions)]
    fn make_hmr_app(framework: frameworks::Framework, prefix: &str) -> Router {
        let config = Arc::new(ViteConfig {
            framework,
            prefix: prefix.to_string(),
            ..Default::default()
        });
        let html_handler = || async {
            axum::response::Response::builder()
                .status(StatusCode::OK)
                .header(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")
                .body(Body::from(
                    "<html><head><title>T</title></head><body></body></html>",
                ))
                .unwrap()
        };
        let js_handler = || async {
            axum::response::Response::builder()
                .status(StatusCode::OK)
                .header(axum::http::header::CONTENT_TYPE, "application/javascript")
                .body(Body::from("console.log('hi')"))
                .unwrap()
        };
        Router::new()
            .route("/page", get(html_handler))
            .route("/app.js", get(js_handler))
            .layer(axum::middleware::from_fn_with_state(
                config,
                hmr_injection_middleware,
            ))
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn hmr_middleware_injects_before_head_close() {
        let app = make_hmr_app(frameworks::Framework::React, "/static/");
        let response = app
            .oneshot(Request::builder().uri("/page").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let html = String::from_utf8(body.to_vec()).unwrap();
        let head_pos = html.find("</head>").expect("missing </head>");
        let client_pos = html
            .find("@vite/client")
            .expect("@vite/client not injected");
        assert!(
            client_pos < head_pos,
            "@vite/client should be before </head>"
        );
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn hmr_middleware_skips_already_injected_html() {
        let config = Arc::new(ViteConfig {
            framework: frameworks::Framework::React,
            ..Default::default()
        });
        let html_with_client = r#"<html><head><script type="module" src="/@vite/client"></script></head><body></body></html>"#;
        let handler = move || {
            let h = html_with_client;
            async move {
                axum::response::Response::builder()
                    .status(StatusCode::OK)
                    .header(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")
                    .body(Body::from(h))
                    .unwrap()
            }
        };
        let app = Router::new()
            .route("/page", get(handler))
            .layer(axum::middleware::from_fn_with_state(
                config,
                hmr_injection_middleware,
            ));
        let response = app
            .oneshot(Request::builder().uri("/page").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let html = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(
            html.matches("@vite/client").count(),
            1,
            "should not double-inject"
        );
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn hmr_middleware_leaves_non_html_untouched() {
        let app = make_hmr_app(frameworks::Framework::None, "/static/");
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/app.js")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body.as_ref(), b"console.log('hi')");
    }

    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn hmr_middleware_respects_custom_prefix_and_framework() {
        let app = make_hmr_app(frameworks::Framework::React, "/assets/");
        let response = app
            .oneshot(Request::builder().uri("/page").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let html = String::from_utf8(body.to_vec()).unwrap();
        assert!(html.contains("/assets/@vite/client"), "wrong prefix used");
        assert!(
            html.contains("/assets/@react-refresh"),
            "wrong framework/prefix"
        );
    }

    // Debug-only: in release builds `embedded_dir!` would try to embed the
    // (nonexistent) path at compile time.
    #[cfg(debug_assertions)]
    #[test]
    fn embedded_dir_returns_none_in_debug_mode() {
        let result: Option<&'static Dir<'static>> =
            embedded_dir!("$CARGO_MANIFEST_DIR/nonexistent/path");
        assert!(
            result.is_none(),
            "embedded_dir! must return None in debug builds"
        );
    }

    #[cfg(debug_assertions)]
    #[test]
    fn embedded_dir_type_is_option_dir() {
        let result: Option<&'static Dir<'static>> = embedded_dir!("$CARGO_MANIFEST_DIR");
        assert!(result.is_none());
    }

    #[cfg(debug_assertions)]
    #[test]
    fn entry_assets_dev_returns_dev_script() {
        let config = ViteConfig {
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let entry = config.entry_assets();
        assert_eq!(entry.script, "/static/src/main.tsx");
        assert!(
            entry.stylesheets.is_empty(),
            "dev mode must not return stylesheets"
        );
    }

    #[cfg(debug_assertions)]
    #[test]
    fn entry_assets_dev_respects_custom_dev_script() {
        let config = ViteConfig {
            prefix: "/assets/".to_string(),
            dev_script: "src/index.ts".to_string(),
            ..Default::default()
        };
        let entry = config.entry_assets();
        assert_eq!(entry.script, "/assets/src/index.ts");
    }

    #[cfg(debug_assertions)]
    #[test]
    fn entry_assets_dev_trims_trailing_slash_in_prefix() {
        let config = ViteConfig {
            prefix: "/static/".to_string(),
            ..Default::default()
        };
        let entry = config.entry_assets();
        assert!(
            !entry.script.contains("//"),
            "double-slash in script path: {}",
            entry.script
        );
    }

    #[test]
    fn entry_assets_from_manifest_happy_path() {
        let json = r#"{
            "index.html": {
                "file": "assets/main-A1b2C3.js",
                "css": ["assets/index-B2c3D4.css"]
            }
        }"#;
        let entry = parse_manifest_for_test(json, "/static", "index.html");
        assert_eq!(entry.script, "/static/assets/main-A1b2C3.js");
        assert_eq!(entry.stylesheets, vec!["/static/assets/index-B2c3D4.css"]);
    }

    #[test]
    fn entry_assets_from_manifest_multiple_css() {
        let json = r#"{
            "index.html": {
                "file": "assets/main.js",
                "css": ["assets/a.css", "assets/b.css"]
            }
        }"#;
        let entry = parse_manifest_for_test(json, "/s", "index.html");
        assert_eq!(entry.stylesheets.len(), 2);
        assert_eq!(entry.stylesheets[0], "/s/assets/a.css");
        assert_eq!(entry.stylesheets[1], "/s/assets/b.css");
    }

    #[test]
    fn entry_assets_from_manifest_no_css_key() {
        let json = r#"{"index.html": {"file": "assets/main.js"}}"#;
        let entry = parse_manifest_for_test(json, "/static", "index.html");
        assert_eq!(entry.script, "/static/assets/main.js");
        assert!(entry.stylesheets.is_empty());
    }

    #[test]
    fn entry_assets_from_manifest_key_not_found_returns_default() {
        let json = r#"{"index.html": {"file": "assets/main.js"}}"#;
        let entry = parse_manifest_for_test(json, "/static", "admin/index.html");
        assert!(
            entry.script.is_empty(),
            "expected empty script on missing key"
        );
        assert!(entry.stylesheets.is_empty());
    }

    #[test]
    fn entry_assets_from_manifest_invalid_json_returns_default() {
        let entry = parse_manifest_for_test("not json at all {{{", "/static", "index.html");
        assert!(entry.script.is_empty());
        assert!(entry.stylesheets.is_empty());
    }

    #[test]
    fn entry_assets_from_manifest_prefix_no_trailing_slash() {
        let json = r#"{"index.html": {"file": "assets/main.js", "css": ["assets/a.css"]}}"#;
        let entry = parse_manifest_for_test(json, "/static/", "index.html");
        assert!(
            !entry.script.contains("//"),
            "double-slash in script: {}",
            entry.script
        );
        assert!(
            !entry.stylesheets[0].contains("//"),
            "double-slash in css: {}",
            entry.stylesheets[0]
        );
    }

    /// Test-only copy of the manifest parsing used by the release-only
    /// `EntryAssets::from_manifest`, so the parsing rules can be checked in
    /// debug test builds.
    fn parse_manifest_for_test(json: &str, base: &str, key: &str) -> EntryAssets {
        let base = base.trim_end_matches('/');
        let Ok(manifest) = serde_json::from_str::<serde_json::Value>(json) else {
            return EntryAssets::default();
        };
        let Some(entries) = manifest.as_object() else {
            return EntryAssets::default();
        };
        let Some(entry) = entries.get(key) else {
            return EntryAssets::default();
        };
        let script = entry
            .get("file")
            .and_then(|f: &serde_json::Value| f.as_str())
            .map(|f| format!("{base}/{f}"))
            .unwrap_or_default();
        let stylesheets = entry
            .get("css")
            .and_then(|c: &serde_json::Value| c.as_array())
            .into_iter()
            .flatten()
            .filter_map(|s: &serde_json::Value| s.as_str())
            .map(|s| format!("{base}/{s}"))
            .collect();
        EntryAssets {
            script,
            stylesheets,
        }
    }
}