use lrwf_core::error::Result;
use lrwf_core::http::{HttpStatus, IHttpContext};
use lrwf_core::middleware::IMiddleware;
use std::path::{Path, PathBuf};
/// Middleware that serves a single-page application from a directory on disk.
///
/// Holds the directory files are read from and the name of the index
/// document that is served when a request does not map to a readable file.
#[derive(Debug, Clone)]
pub struct SpaMiddleware {
    /// Directory that request paths are joined onto.
    root: PathBuf,
    /// File name, relative to `root`, of the fallback index document.
    index: String,
}
impl SpaMiddleware {
    /// Creates a middleware serving `root` with `index.html` as the index
    /// document.
    ///
    /// The given root is passed through `resolve_spa_root` before it is
    /// stored.
    pub fn new(root: impl Into<PathBuf>) -> Self {
        let root = resolve_spa_root(root.into());
        let index = String::from("index.html");
        Self { root, index }
    }

    /// Creates a middleware serving `root` with a caller-chosen index
    /// document.
    ///
    /// Unlike [`SpaMiddleware::new`], the root is stored exactly as given; it
    /// is not passed through `resolve_spa_root`.
    pub fn with_index(root: impl Into<PathBuf>, index: impl Into<String>) -> Self {
        let root: PathBuf = root.into();
        let index: String = index.into();
        Self { root, index }
    }
}
#[async_trait::async_trait]
impl IMiddleware for SpaMiddleware {
    /// Answers `GET` requests with a file from the SPA root, falling back to
    /// the index document.
    ///
    /// * Any method other than `GET` (compared after upper-casing) returns
    ///   `Ok(())` without touching the response.
    /// * When the path produced by `resolve_file` can be read, its bytes are
    ///   written with status `OK` and a content type chosen by `mime_type`.
    /// * When that read fails, the index document under the root is written
    ///   with status `OK` and content type `text/html`.
    /// * When the index document cannot be read either, nothing is written and
    ///   `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `write_bytes` on the response.
    async fn invoke(&self, ctx: &mut dyn IHttpContext) -> Result<()> {
        if ctx.request().method().to_uppercase() != "GET" {
            return Ok(());
        }

        let target = self.resolve_file(ctx.request().path());

        // Pick the payload and its content type first; the response is then
        // written in exactly one place.
        let (body, content_type) = match tokio::fs::read(&target).await {
            Ok(bytes) => (bytes, mime_type(&target)),
            Err(_) => {
                let fallback = self.root.join(&self.index);
                match tokio::fs::read(&fallback).await {
                    Ok(bytes) => (bytes, "text/html"),
                    // Neither the requested file nor the index is readable:
                    // leave the response untouched.
                    Err(_) => return Ok(()),
                }
            }
        };

        ctx.response_mut().set_status(HttpStatus::OK);
        ctx.response_mut().set_header("content-type", content_type);
        ctx.response_mut().write_bytes(body).await?;
        Ok(())
    }
}
impl SpaMiddleware {
    /// Maps a request path to the file that should be read for it.
    ///
    /// Leading `/` characters are stripped and the remainder is joined onto
    /// the root. The result is:
    ///
    /// * the index document when the remainder is empty;
    /// * the canonicalized target when it canonicalizes and lies under the
    ///   canonicalized root (the root as stored is used if the root itself
    ///   cannot be canonicalized);
    /// * the un-canonicalized target when canonicalization fails but
    ///   `is_safe_subpath` accepts it;
    /// * the index document in every other case.
    fn resolve_file(&self, request_path: &str) -> PathBuf {
        let index_document = || self.root.join(&self.index);

        let trimmed = request_path.trim_start_matches('/');
        if trimmed.is_empty() {
            return index_document();
        }

        let target = self.root.join(trimmed);

        if let Ok(real_target) = target.canonicalize() {
            let real_root = self
                .root
                .canonicalize()
                .unwrap_or_else(|_| self.root.clone());
            return if real_target.starts_with(&real_root) {
                real_target
            } else {
                index_document()
            };
        }

        // The target could not be canonicalized, so only a non-canonical
        // check is possible here.
        if is_safe_subpath(&self.root, &target) {
            target
        } else {
            index_document()
        }
    }
}
fn is_safe_subpath(root: &Path, candidate: &Path) -> bool {
let normalized = normalize_path(candidate);
let root_abs = root.canonicalize().unwrap_or_else(|_| root.to_path_buf());
if normalized.is_absolute() {
return normalized.starts_with(&root_abs);
}
if let Ok(cwd) = std::env::current_dir() {
let abs_candidate = cwd.join(&normalized);
if let Ok(canon) = abs_candidate.canonicalize() {
return canon.starts_with(&root_abs);
}
}
true
}
/// Lexically normalizes `path` without consulting the filesystem.
///
/// * `.` components are dropped.
/// * `..` cancels the preceding normal component.
/// * `..` directly after the root directory is discarded, so an absolute path
///   stays absolute (`/a/../../b` becomes `/b`).
/// * `..` with nothing to cancel is kept, so a relative path that climbs out
///   of its starting directory still shows it (`../a` stays `../a`).
///
/// Symbolic links are not resolved. A path consisting only of `.` normalizes
/// to the empty path.
fn normalize_path(path: &Path) -> PathBuf {
    let mut kept: Vec<std::path::Component<'_>> = Vec::new();
    for component in path.components() {
        match component {
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => match kept.last().copied() {
                Some(std::path::Component::Normal(_)) => {
                    kept.pop();
                }
                // The parent of the root is the root itself.
                Some(std::path::Component::RootDir) => {}
                // Nothing to cancel (empty, a previous `..`, or a bare
                // prefix): keep the `..` instead of silently dropping it.
                _ => kept.push(component),
            },
            other => kept.push(other),
        }
    }
    kept.iter().map(|part| part.as_os_str()).collect()
}
/// Returns the content type to send for `path`, chosen by file extension.
///
/// The extension is compared ASCII-case-insensitively, so `logo.PNG` and
/// `logo.png` are treated alike. A missing, non-UTF-8, or unrecognized
/// extension yields `application/octet-stream`.
fn mime_type(path: &Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    match ext.as_str() {
        "html" | "htm" => "text/html",
        "js" | "mjs" => "application/javascript",
        "css" => "text/css",
        "json" => "application/json",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "wasm" => "application/wasm",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "eot" => "application/vnd.ms-fontobject",
        "txt" => "text/plain",
        "xml" => "application/xml",
        "pdf" => "application/pdf",
        "zip" => "application/zip",
        _ => "application/octet-stream",
    }
}
/// Locates the directory a possibly-relative `root` refers to.
///
/// An absolute root, or one that already exists, is returned unchanged.
/// Otherwise the current working directory and then each of its ancestors is
/// scanned: for every immediate subdirectory, `root` is joined onto it and the
/// first joined path that exists is returned. Directories that cannot be
/// listed and entries that cannot be read are skipped. When nothing matches,
/// or the working directory is unavailable, `root` is returned as given.
fn resolve_spa_root(root: PathBuf) -> PathBuf {
    if root.is_absolute() || root.exists() {
        return root;
    }

    let cwd = match std::env::current_dir() {
        Ok(cwd) => cwd,
        Err(_) => return root,
    };

    // `ancestors` yields the directory itself first, then each parent in turn.
    let located = cwd
        .ancestors()
        .filter_map(|dir| std::fs::read_dir(dir).ok())
        .flat_map(|entries| entries.flatten())
        .map(|entry| entry.path())
        .filter(|subdir| subdir.is_dir())
        .map(|subdir| subdir.join(&root))
        .find(|joined| joined.exists());

    located.unwrap_or(root)
}