1use anyhow::Result;
2use bytes::Bytes;
3use http_body_util::Full;
4use hyper::header::{self, HeaderName, HeaderValue};
5use hyper::server::conn::http1;
6use hyper::service::service_fn;
7use hyper::{Method, Request, Response, StatusCode};
8use hyper_util::rt::TokioIo;
9use socket2::{Domain, Protocol, Socket, Type};
10use std::net::SocketAddr;
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tracing::{debug, info, warn};
15
16use crate::routes::NanoWeb;
17
/// Runtime options for [`start_server`].
#[derive(Clone, Debug)]
pub struct ServeConfig {
    /// Directory whose files are loaded into the route table and served.
    pub public_dir: PathBuf,
    /// TCP port the server binds on all IPv4 interfaces (`0.0.0.0`).
    pub port: u16,
    /// When `true`, each request first asks the route table to refresh the
    /// requested path if it changed on disk.
    pub dev: bool,
    /// When `true`, unknown paths fall back to the `/` route (single-page apps).
    pub spa_mode: bool,
    /// Prefix passed through to route population/refresh.
    /// NOTE(review): exact meaning is defined by `NanoWeb` — confirm there.
    pub config_prefix: String,
    /// When `true`, every routed request is logged at `info` level.
    pub log_requests: bool,
}
27
28struct AppState {
29 server: Arc<NanoWeb>,
30 config: ServeConfig,
31}
32
33fn create_reuse_port_listener(addr: SocketAddr) -> Result<std::net::TcpListener> {
35 let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
36 socket.set_reuse_address(true)?;
37 #[cfg(unix)]
38 socket.set_reuse_port(true)?;
39 socket.set_nonblocking(true)?;
40 socket.bind(&addr.into())?;
41 socket.listen(8192)?; Ok(socket.into())
43}
44
45pub async fn start_server(config: ServeConfig) -> Result<()> {
46 let server = Arc::new(NanoWeb::new());
47 server.populate_routes(&config.public_dir, &config.config_prefix)?;
48
49 let state = Arc::new(AppState {
50 server,
51 config: config.clone(),
52 });
53
54 info!("Routes loaded: {}", state.server.route_count());
55
56 let addr: SocketAddr = ([0, 0, 0, 0], config.port).into();
57 let std_listener = create_reuse_port_listener(addr)?;
58 let listener = TcpListener::from_std(std_listener)?;
59
60 info!("Starting server on http://{}", addr);
61 info!("Serving directory: {:?}", config.public_dir);
62
63 let shutdown = shutdown_signal();
64 tokio::pin!(shutdown);
65
66 loop {
67 tokio::select! {
68 result = listener.accept() => {
69 let (stream, _) = result?;
70 let io = TokioIo::new(stream);
71 let state = state.clone();
72
73 tokio::spawn(async move {
74 let service = service_fn(move |req| {
75 let state = state.clone();
76 async move { handle_request(req, state) }
77 });
78
79 if let Err(e) = http1::Builder::new()
80 .keep_alive(true)
81 .pipeline_flush(true)
82 .serve_connection(io, service)
83 .await
84 {
85 debug!("Connection error: {}", e);
86 }
87 });
88 }
89 () = &mut shutdown => {
90 info!("Shutdown signal received, stopping server");
91 break;
92 }
93 }
94 }
95
96 Ok(())
97}
98
99async fn shutdown_signal() {
100 let ctrl_c = async {
101 tokio::signal::ctrl_c()
102 .await
103 .expect("failed to install Ctrl+C handler");
104 };
105
106 #[cfg(unix)]
107 let terminate = async {
108 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
109 .expect("failed to install SIGTERM handler")
110 .recv()
111 .await;
112 };
113
114 #[cfg(not(unix))]
115 let terminate = std::future::pending::<()>();
116
117 tokio::select! {
118 () = ctrl_c => {},
119 () = terminate => {},
120 }
121}
122
123type HyperResponse = Response<Full<Bytes>>;
124
125#[allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
126fn handle_request(
127 req: Request<hyper::body::Incoming>,
128 state: Arc<AppState>,
129) -> Result<HyperResponse, std::convert::Infallible> {
130 let is_head = req.method() == Method::HEAD;
131
132 if req.method() != Method::GET && !is_head {
133 return Ok(response(
134 StatusCode::METHOD_NOT_ALLOWED,
135 "Method Not Allowed",
136 ));
137 }
138
139 let raw_path = req.uri().path();
140
141 if raw_path == "/_health" {
143 let body = format!(
144 r#"{{"status":"ok","timestamp":"{}"}}"#,
145 httpdate::fmt_http_date(std::time::SystemTime::now())
146 );
147 return Ok(Response::builder()
148 .status(StatusCode::OK)
149 .header("content-type", "application/json")
150 .body(Full::new(Bytes::from(body)))
151 .expect("health check response"));
152 }
153
154 let path = match crate::path::validate_request_path(raw_path) {
156 Ok(sanitized) => sanitized,
157 Err(e) => {
158 warn!("Path validation failed for '{}': {}", raw_path, e);
159 return Ok(response(StatusCode::BAD_REQUEST, "Bad Request"));
160 }
161 };
162
163 if state.config.dev {
165 let _ = state.server.refresh_if_modified(
166 &path,
167 &state.config.public_dir,
168 &state.config.config_prefix,
169 );
170 }
171
172 let accept_encoding = req
173 .headers()
174 .get("accept-encoding")
175 .and_then(|h| h.to_str().ok())
176 .unwrap_or("");
177
178 let if_none_match = req
179 .headers()
180 .get("if-none-match")
181 .and_then(|h| h.to_str().ok());
182
183 let mut buf = state.server.get_response(&path, accept_encoding);
184
185 if buf.is_none() && !path.ends_with('/') {
187 let with_slash = format!("{path}/");
188 buf = state.server.get_response(&with_slash, accept_encoding);
189 }
190
191 if buf.is_none() && state.config.spa_mode {
193 debug!("SPA fallback for: {}", path);
194 buf = state.server.get_response("/", accept_encoding);
195 }
196
197 let resp = if let Some(ref b) = buf {
198 if let Some(etag) = if_none_match {
200 if etag == b.etag.as_ref() {
201 return Ok(Response::builder()
202 .status(StatusCode::NOT_MODIFIED)
203 .header("etag", b.etag.as_ref())
204 .header("cache-control", b.cache_control.as_ref())
205 .body(Full::new(Bytes::new()))
206 .expect("304 response"));
207 }
208 }
209 build_response(b, is_head)
210 } else {
211 debug!("Route not found: {path}");
212 response(StatusCode::NOT_FOUND, "Not Found")
213 };
214
215 if state.config.log_requests {
216 info!(
217 method = %req.method(),
218 path = %path,
219 status = resp.status().as_u16(),
220 "request"
221 );
222 }
223
224 Ok(resp)
225}
226
227fn response(status: StatusCode, body: &'static str) -> HyperResponse {
228 Response::builder()
229 .status(status)
230 .body(Full::new(Bytes::from_static(body.as_bytes())))
231 .expect("error response")
232}
233
234fn build_response(buf: &crate::response_buffer::ResponseBuffer, head_only: bool) -> HyperResponse {
235 let mut builder = Response::builder()
236 .status(StatusCode::OK)
237 .header(header::CONTENT_TYPE, buf.content_type.as_ref())
238 .header(header::ETAG, buf.etag.as_ref())
239 .header(header::LAST_MODIFIED, buf.last_modified.as_ref())
240 .header(header::CACHE_CONTROL, buf.cache_control.as_ref())
241 .header(header::CONTENT_LENGTH, buf.content_length.as_ref())
243 .header(
244 header::X_CONTENT_TYPE_OPTIONS,
245 HeaderValue::from_static("nosniff"),
246 )
247 .header(
248 header::X_FRAME_OPTIONS,
249 HeaderValue::from_static("SAMEORIGIN"),
250 )
251 .header(
252 header::REFERRER_POLICY,
253 HeaderValue::from_static("strict-origin-when-cross-origin"),
254 )
255 .header(
256 header::STRICT_TRANSPORT_SECURITY,
257 HeaderValue::from_static("max-age=63072000; includeSubDomains"),
258 )
259 .header(
260 HeaderName::from_static("permissions-policy"),
261 HeaderValue::from_static("camera=(), microphone=(), geolocation=()"),
262 )
263 .header(
264 HeaderName::from_static("x-dns-prefetch-control"),
265 HeaderValue::from_static("off"),
266 );
267
268 if let Some(encoding) = buf.content_encoding {
269 builder = builder.header(header::CONTENT_ENCODING, HeaderValue::from_static(encoding));
270 }
271
272 if buf.vary_encoding {
274 builder = builder.header(header::VARY, HeaderValue::from_static("Accept-Encoding"));
275 }
276
277 let body = if head_only {
278 Bytes::new()
279 } else {
280 buf.body.clone()
281 };
282
283 builder.body(Full::new(body)).expect("response body")
284}