1use lrwf_core::error::Result;
7use lrwf_core::http::{HttpStatus, IHttpContext};
8use lrwf_core::middleware::IMiddleware;
9use std::path::{Path, PathBuf};
10
/// Middleware that serves a single-page application from a directory on disk.
///
/// `GET` requests are mapped onto files under `root`; when the requested file
/// cannot be read, the `index` document under `root` is served instead.
#[derive(Debug, Clone)]
pub struct SpaMiddleware {
    /// Directory the static assets are read from.
    root: PathBuf,
    /// File name, joined onto `root`, that is served as the fallback document.
    index: String,
}
21
22impl SpaMiddleware {
23 pub fn new(root: impl Into<PathBuf>) -> Self {
30 Self {
31 root: resolve_spa_root(root.into()),
32 index: "index.html".to_string(),
33 }
34 }
35
36 pub fn with_index(root: impl Into<PathBuf>, index: impl Into<String>) -> Self {
38 Self {
39 root: root.into(),
40 index: index.into(),
41 }
42 }
43}
44
45#[async_trait::async_trait]
46impl IMiddleware for SpaMiddleware {
47 async fn invoke(&self, ctx: &mut dyn IHttpContext) -> Result<()> {
48 let method = ctx.request().method().to_uppercase();
49 if method != "GET" {
50 return Ok(());
51 }
52
53 let request_path = ctx.request().path();
54 let file_path = self.resolve_file(request_path);
55
56 match tokio::fs::read(&file_path).await {
57 Ok(data) => {
58 ctx.response_mut().set_status(HttpStatus::OK);
59 ctx.response_mut()
60 .set_header("content-type", mime_type(&file_path));
61 ctx.response_mut().write_bytes(data).await?;
62 }
63 Err(_) => {
64 let index_path = self.root.join(&self.index);
66 match tokio::fs::read(&index_path).await {
67 Ok(data) => {
68 ctx.response_mut().set_status(HttpStatus::OK);
69 ctx.response_mut().set_header("content-type", "text/html");
70 ctx.response_mut().write_bytes(data).await?;
71 }
72 Err(_) => {
73 }
75 }
76 }
77 }
78
79 Ok(())
80 }
81}
82
83impl SpaMiddleware {
84 fn resolve_file(&self, request_path: &str) -> PathBuf {
86 let relative = request_path.trim_start_matches('/');
87 if relative.is_empty() {
88 return self.root.join(&self.index);
89 }
90
91 let candidate = self.root.join(relative);
92
93 match candidate.canonicalize() {
97 Ok(resolved) => {
98 let root_canonical = self
99 .root
100 .canonicalize()
101 .unwrap_or_else(|_| self.root.clone());
102 if resolved.starts_with(&root_canonical) {
103 resolved
104 } else {
105 self.root.join(&self.index)
107 }
108 }
109 Err(_) => {
110 if is_safe_subpath(&self.root, &candidate) {
113 candidate
114 } else {
115 self.root.join(&self.index)
116 }
117 }
118 }
119 }
120}
121
/// Lexically checks that `candidate` stays inside `root`.
///
/// Both paths are normalized with [`normalize_path`] (no filesystem access, symlinks
/// are not followed). `candidate` is accepted only if it equals `root` or descends
/// from it through ordinary path components. Anything that cannot be shown to stay
/// inside `root` is rejected, so the check fails closed. Rejected cases include a
/// different prefix, a `..` that climbs above `root`, or an absolute path that
/// replaced `root` during `join`.
fn is_safe_subpath(root: &Path, candidate: &Path) -> bool {
    let root = normalize_path(root);
    let candidate = normalize_path(candidate);
    match candidate.strip_prefix(&root) {
        // After normalization a `..` can only survive at the front of the remainder
        // (e.g. root `..`, candidate `../../x` leaves `../x`), so require plain names.
        Ok(rest) => rest
            .components()
            .all(|c| matches!(c, std::path::Component::Normal(_))),
        Err(_) => false,
    }
}

/// Lexically normalizes `path`: drops `.` components and resolves `..` against the
/// preceding component, without touching the filesystem.
///
/// A `..` that would climb above the start of a relative path is kept
/// (`a/../../x` → `../x`). A `..` directly after the root directory is dropped
/// (`/../etc` → `/etc`). A naive stack pop would instead lose the `..` or discard the
/// root and silently turn the path relative.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    let mut parts: Vec<Component<'_>> = Vec::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match parts.last() {
                Some(Component::Normal(_)) => {
                    parts.pop();
                }
                // `/..` is `/`: nothing lies above the root directory.
                Some(Component::RootDir) => {}
                // A leading `..` on a relative (or drive-relative) path must be preserved.
                Some(Component::ParentDir | Component::CurDir | Component::Prefix(_)) | None => {
                    parts.push(component);
                }
            },
            other => parts.push(other),
        }
    }
    parts.into_iter().collect()
}
165
/// Returns the `content-type` value for `path`, based on its file extension.
///
/// The extension is matched case-insensitively (`LOGO.PNG` → `image/png`). Paths
/// with no extension, a non-UTF-8 extension, or an unknown one map to
/// `application/octet-stream`.
fn mime_type(path: &Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    match ext.as_str() {
        "html" | "htm" => "text/html",
        "js" | "mjs" => "application/javascript",
        "css" => "text/css",
        // Source maps are JSON documents.
        "json" | "map" => "application/json",
        "webmanifest" => "application/manifest+json",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "wasm" => "application/wasm",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "eot" => "application/vnd.ms-fontobject",
        "txt" => "text/plain",
        "csv" => "text/csv",
        "xml" => "application/xml",
        "pdf" => "application/pdf",
        "zip" => "application/zip",
        "mp4" => "video/mp4",
        "webm" => "video/webm",
        "mp3" => "audio/mpeg",
        _ => "application/octet-stream",
    }
}
190
/// Locates the directory a configured SPA root refers to.
///
/// Absolute roots are returned unchanged. So are relative roots that already exist
/// relative to the current working directory. Otherwise the working directory and
/// each of its ancestors are scanned in turn (nearest first). The first existing
/// `<ancestor>/<subdir>/<root>` is returned. Within one ancestor, subdirectories are
/// tried in sorted order, so the result no longer depends on the platform's
/// `read_dir` enumeration order. If nothing matches, or the working directory cannot
/// be determined, `root` is returned as given.
fn resolve_spa_root(root: PathBuf) -> PathBuf {
    if root.is_absolute() || root.exists() {
        return root;
    }

    let found = std::env::current_dir()
        .ok()
        .and_then(|cwd| cwd.ancestors().find_map(|dir| find_root_in_subdirs(dir, &root)));
    found.unwrap_or(root)
}

/// Returns the first existing `<dir>/<child>/<root>`, trying the child directories of
/// `dir` in sorted order; an unreadable `dir` yields `None`.
fn find_root_in_subdirs(dir: &Path, root: &Path) -> Option<PathBuf> {
    let mut subdirs: Vec<PathBuf> = std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .collect();
    subdirs.sort();
    subdirs
        .into_iter()
        .map(|subdir| subdir.join(root))
        .find(|candidate| candidate.exists())
}